Example #1
    def _init(self):
        self.aircraft_dict = {}

        self.a2a_list = []
        self.target_list = []
        self.blue_list = []
        self.blue_dic = {}

        self.attacking_targets = {}

        self.awacs_team_id = -1
        self.disturb_team_id = -1

        self.agent_state = 0
        self.disturb_state = RedAgentState.AIR_DISTURB1
        self.area_hurt_a = RedAgentState.AREA_HUNT11
        self.area_hurt_b = RedAgentState.AREA_HUNT11
        self.area_hurt_c = RedAgentState.AREA_HUNT11
        self.area_hurt_d = RedAgentState.AREA_HUNT11
        self.air_attack_time = 0
        self.a2g_ha = 0
        self.a2g_hb = 0
        self.team_id_dic = {}
        self.Task = None
        self.dqn = DQN()
        self.reward = 0
        self.done = False
        ### Arbitrarily initialize a state vector first
        self.last_state = [10000]*config.a2a_LX11 + [6]*config.a2a_LX11 + \
                          [10000] * config.a2a_LX11 + [1] * config.a2a_LX11
Example #2
def main(config_path, env_name, train_mode=True, weights_path=None):
    """Load the environment, create an agent, and train it.
    """
    config = cutils.get_config(config_path)
    env = cutils.load_environment(env_name)
    action_size = env.action_space.n
    state_size = env.observation_space.shape

    memory = ReplayMem(buffer=config['exp_replay']['buffer'])
    av_model = SimpleNN(input_shape=state_size[0], output_shape=action_size)
    policy = EpsGreedy(eps=config['train']['eps_start'],
                       decay=config['train']['eps_decay'],
                       eps_end=config['train']['eps_end'])

    agent = DQN(config,
                seed=0,
                ob_space=state_size[0],
                ac_space=action_size,
                av_model=av_model,
                memory=memory,
                policy=policy)

    if weights_path is not None:
        agent.load(weights_path)

    game_logger = GameLogger(100,
                             10)  # TODO Add winning threshold to arguments
    player = Player(agent=agent,
                    env=env,
                    config=config,
                    game_logger=game_logger,
                    train_mode=train_mode)
    player.play()

    return player.glogger.scores
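
A command-line entry point is not shown in this example; what follows is a minimal sketch of how main() could be invoked. The flag names and defaults are assumptions for illustration, not part of the original script.

import argparse

if __name__ == '__main__':
    # Hypothetical CLI wrapper around main(); all flag names and defaults are illustrative.
    parser = argparse.ArgumentParser(description='Train or evaluate a DQN agent.')
    parser.add_argument('--config', default='config.yaml', help='path to the config file (assumed)')
    parser.add_argument('--env', default='CartPole-v0', help='environment name passed to load_environment (assumed)')
    parser.add_argument('--weights', default=None, help='optional path to pretrained weights')
    parser.add_argument('--eval', action='store_true', help='evaluation mode (disables training)')
    cli = parser.parse_args()

    scores = main(cli.config, cli.env, train_mode=not cli.eval, weights_path=cli.weights)
    print('Scores from the last 10 episodes:', scores[-10:])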
Example #3
    def __init__(self, action_set, hParam):

        h, w = 84, 84
        self.qNetwork = DQN(h, w, len(action_set))
        self.targetNetwork = DQN(h, w, len(action_set))
        self.targetNetwork.load_state_dict(self.qNetwork.state_dict())

        self.optimizer = optim.Adam(self.qNetwork.parameters(),
                                    lr=1e-4)
        self.loss_func = nn.MSELoss()

        self.memory = ReplayMemory(hParam['BUFFER_SIZE']) #

        self.DISCOUNT_FACTOR = hParam['DISCOUNT_FACTOR'] # 0.99 

        self.steps_done = 0
        self.EPS_START = hParam['EPS_START'] # 1.0
        self.EPS_END = hParam['EPS_END']
        self.EPS_ITER = 1000000
        self.MAX_ITER = hParam['MAX_ITER']
        self.eps_threshold = self.EPS_START
        self.BATCH_SIZE = hParam['BATCH_SIZE']

        self.n_actions = len(action_set) # 2

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.qNetwork.to(self.device)
        self.targetNetwork.to(self.device)
        self.qNetwork.train()
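
The constructor above pulls every hyperparameter from a single hParam dictionary whose contents are not shown. A sketch of the keys it reads is below; the values are placeholders except where the inline comments above hint at them (0.99 and 1.0).

# Hypothetical hyperparameter dictionary; the key names match what the constructor reads,
# the values are illustrative only.
hParam = {
    'BUFFER_SIZE': 100000,       # capacity of ReplayMemory
    'DISCOUNT_FACTOR': 0.99,     # gamma (the comment above suggests 0.99)
    'EPS_START': 1.0,            # initial epsilon (the comment above suggests 1.0)
    'EPS_END': 0.1,              # final epsilon
    'MAX_ITER': 2000000,         # total training iterations
    'BATCH_SIZE': 32,            # minibatch size used when sampling from memory
}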
Example #4
    def predict(self, transaction, debug=False):
        if debug:
            return random.choice(CoinAgent.actions), [0, 0, 0]

        main_dqn = DQN(Paths.MODEL, len(self.actions))
        action_dist = main_dqn.predict([transaction])[0]
        action = CoinAgent.Action(np.argmax(action_dist))
        return action, action_dist
Example #5
    def test_generate_inputs(self):
        dqn = DQN(Paths.MODEL, len(CoinAgent.actions))
        transactions = CoinTransaction.get_transactions(Paths.DATA,
                                                        'eth',
                                                        max_size=5)

        for result in dqn.generate_input(transactions, []):
            self.assertIsNotNone(result)
Example #6
    def evaluate(self, transactions):
        main_dqn = DQN(Paths.MODEL, len(self.actions))

        action_dists = main_dqn.predict(transactions)
        action_max_indices = np.argmax(action_dists, axis=1)
        actions = [CoinAgent.Action(index) for index in action_max_indices]

        portfolios, _ = self.__get_rewards(transactions, actions)
        print('Portfolio: {:>12,.2f}, {}, {}'.format(
            portfolios[-1], Counter(actions),
            Counter([tuple(x) for x in action_dists]).most_common(2)))

        return portfolios
Example #7
    def train(self, transactions, params):
        r = params['r']
        epoch = params['epoch']

        main_dqn = DQN(Paths.MODEL, len(self.actions))
        copied = self.__copy_model(Paths.MODEL)
        target_dqn = DQN(copied, len(self.actions))

        for n in range(epoch):
            action_dists = main_dqn.predict(transactions)
            action_max_indices = np.argmax(action_dists, axis=1)
            actions = [CoinAgent.Action(index) for index in action_max_indices]

            next_action_dists = target_dqn.predict(transactions[1:])
            next_action_max_values = np.max(next_action_dists, axis=1)

            portfolios, rewards = self.__get_rewards(transactions, actions)

            target_dists = []
            for i in range(len(transactions)):
                target_dist = list(action_dists[i])
                target_dist[action_max_indices[i]] = rewards[i]
                if i < len(transactions) - 1:
                    target_dist[
                        action_max_indices[i]] += r * next_action_max_values[i]
                target_dists.append(target_dist)

            result = main_dqn.train(transactions, target_dists)
            print('[{}] #{}, Loss: {:>8,.4f}, Portfolio: {:>12,.2f}, {}, {}'.
                  format(
                      n, result['global_step'], result['loss'], portfolios[-1],
                      Counter(actions),
                      Counter([tuple(x)
                               for x in action_dists]).most_common(2)))

            if n > 0 and n % 10 == 0:
                copied = self.__copy_model(Paths.MODEL)
                target_dqn = DQN(copied, len(self.actions))
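
The inner loop above builds the standard one-step Q-learning target: the main network's predicted distribution is copied, and its entry at the greedy action is overwritten with reward + r * max_a' Q_target(next state, a'), with no bootstrap term for the last transaction. A vectorized sketch of the same computation (build_targets is illustrative, not a helper from the original code):

import numpy as np

def build_targets(action_dists, action_max_indices, rewards, next_max_values, r):
    # action_dists: (N, n_actions) predictions of the main network
    # next_max_values: (N - 1,) max target-network values for transactions[1:]
    targets = np.array(action_dists, dtype=float)
    bootstrap = np.zeros(len(rewards))
    bootstrap[:-1] = r * np.asarray(next_max_values)
    targets[np.arange(len(rewards)), action_max_indices] = np.asarray(rewards) + bootstrap
    return targets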
Example #8
}

env = gym.make("CartPole-v0")
env = CartPole_Pixel(
    env, True, False, reference_domain,
    colors)  # Randomize = True, regularize = False for testing
env = FrameStack(env, 3)
agent = DQN(env,
            CNN_cartpole,
            replay_start_size=1000,
            replay_buffer_size=100000,
            gamma=0.99,
            update_target_frequency=1000,
            minibatch_size=32,
            learning_rate=1e-4,
            initial_exploration_rate=1.,
            final_exploration_rate=0.01,
            final_exploration_step=10000,
            adam_epsilon=1e-4,
            logging=True,
            loss='mse',
            lam=0,
            regularize=False,
            add_current_step=ADD_CURRENT_STEP)

scores_all = np.zeros((3, NUM_SEEDS, len(colors)))  # 3 agents

for name_id, name in enumerate(agents.keys()):
    for se in range(NUM_SEEDS):
        print('seed : ', se)
        agent.load(agents[name][se])
Example #9
from agent.dqn import DQN
from hyperparameters import Config

config = Config()

config.name = 'CartPole-v0'
# Initialize the environment
env = gym.make(config.name)
env.seed(1)
env = env.unwrapped

# Initialize the agent
agent = DQN(config=config,
            env=env,
            doubleDQN=False,
            duelingDQN=True,
            NoisyDQN=False,
            N_stepDQN=False,
            Prioritized=False)

# agent.load()

iteration = 0  # total number of steps

# Record the start time
start_time = time.time()
ep_reward_list = []  # stores the reward of each episode
mean_ep_reward_list = []  # average reward over the whole training run

# Loop for config.episode episodes
for e in range(config.episode):
Example #10
    print('seed : ', se)
    env = gym.make("CartPole-v0")
    env = CartPole_Pixel(env, RANDOMIZE, REGULARIZE, reference_domain, colors)
    env = FrameStack(env, 3)
    env.metadata['_max_episode_steps'] = 200  #useless
    env.reset()

    agent = DQN(env,
                CNN_cartpole,
                replay_start_size=1000,
                replay_buffer_size=100000,
                gamma=0.99,
                update_target_frequency=1000,
                minibatch_size=32,
                learning_rate=1e-4,
                initial_exploration_rate=1.,
                final_exploration_rate=0.01,
                final_exploration_step=10000,
                adam_epsilon=1e-4,
                logging=True,
                loss='mse',
                log_folder='results/' + str(name) + '_' + str(se),
                lam=LAMBDA,
                regularize=REGULARIZE,
                add_current_step=ADD_CURRENT_STEP)

    agent.learn(timesteps=nb_steps, verbose=True)
    agent.save()

    env.close()
Example #11
class RedDQNAgent(Agent):
    def __init__(self, name, config, **kwargs):
        super().__init__(name, config['side'])

        self._init()

    def _init(self):
        self.aircraft_dict = {}

        self.a2a_list = []
        self.target_list = []
        self.blue_list = []
        self.blue_dic = {}

        self.attacking_targets = {}

        self.awacs_team_id = -1
        self.disturb_team_id = -1

        self.agent_state = 0
        self.disturb_state = RedAgentState.AIR_DISTURB1
        self.area_hurt_a = RedAgentState.AREA_HUNT11
        self.area_hurt_b = RedAgentState.AREA_HUNT11
        self.area_hurt_c = RedAgentState.AREA_HUNT11
        self.area_hurt_d = RedAgentState.AREA_HUNT11
        self.air_attack_time = 0
        self.a2g_ha = 0
        self.a2g_hb = 0
        self.team_id_dic = {}
        self.Task = None
        self.dqn = DQN()
        self.reward = 0
        self.done = False
        ### Arbitrarily initialize a state vector first
        self.last_state = [10000]*config.a2a_LX11 + [6]*config.a2a_LX11 + \
                          [10000] * config.a2a_LX11 + [1] * config.a2a_LX11

    def reset(self):
        self._init()

    def step(self, sim_time, obs_red, **kwargs):

        self._parse_teams(obs_red)

        cmd_list = []
        self.done = False
        self._parse_observation(obs_red)
        # print('Red-side intel:', obs_red['qb'])
        '''First wave'''
        # Frigate initialization
        if self.agent_state == 0:
            index = 1
            for ship in obs_red['units']:
                if ship['LX'] == 21:
                    if index == 1:
                        cmd_list.extend(
                            self._ship_movedeploy(ship['ID'], SHIP_POINT1))
                        print('Frigate 1 in position')
                        index += 1
                        continue
                    if index == 2:
                        cmd_list.extend(
                            self._ship_movedeploy(ship['ID'], SHIP_POINT2))
                        print('Frigate 2 in position')
                        index += 1
                        continue
                    if index == 3:
                        cmd_list.extend(
                            self._ship_movedeploy(ship['ID'], SHIP_POINT3))
                        print('Frigate 3 in position')
                        index += 1
                        continue
            self.agent_state = 1

        # 1 AWACS--YA + 2-fighter escort team--JA
        if self.agent_state == 1:
            for awas in obs_red['units']:
                if awas['LX'] == 12:
                    cmd_list.extend(
                        self._awacs_patrol(awas['ID'], AWACS_PATROL_POINT,
                                           AWACS_PATROL_PARAMS))
                    print('AWACS on patrol')
                    self.agent_state = 5
        if self.agent_state == 5:
            if 'YA' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._awacs_escort(self.team_id_dic['YA']))
                print('Escorting the AWACS')
                self.agent_state = 6

        # 3 jammers (used together for 360° jamming)--RA + 2-fighter escort team--JB
        # The released version aggregates the models, i.e. only one jammer is provided (hence the change here)
        if self.agent_state == 6:
            cmd_list.extend(
                self._takeoff_areapatrol(1, 13, AIR_DISTURB_POINT1,
                                         DISTURB_PATROL_PARAMS))
            print('Jammer taking off')
            self.agent_state = 7
        if self.agent_state == 7:
            if 'RA' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._disturb_escort(self.team_id_dic['RA']))
                print('Escorting the jammer')
                self.agent_state = 8

        # Issue the area-jamming command once the jammer reaches the staging area
        # Start route jamming in the staging area
        if self.disturb_state == RedAgentState.AIR_DISTURB1:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13 and 15000 < disturb[
                        'X'] < 75000 and -20000 < disturb['Y'] < 20000:
                    cmd_list.extend(
                        self._disturb_linepatrol(self.team_id_dic['RA'],
                                                 NORTH_LINE))
                    self.disturb_state = RedAgentState.AIR_DISTURB2
                    print('Route jamming started in the staging area')
        # Head to the designated northern area to jam
        if self.disturb_state == RedAgentState.AIR_DISTURB2:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13 and -50000 < disturb[
                        'X'] < -40000 and 50000 < disturb['Y'] < 60000:
                    cmd_list.extend(
                        self._disturb_patrol(self.team_id_dic['RA'],
                                             AIR_DISTURB_POINT2,
                                             DISTURB_PATROL_PARAMS))
                    self.disturb_state = RedAgentState.AIR_DISTURB3
                    print('Area jamming started in the staging area')
        # Jammer heads south to jam
        ship_flag = True
        if self.disturb_state == RedAgentState.AIR_DISTURB3:
            for blue_unit in obs_red['qb']:
                if blue_unit['LX'] == 21:
                    ship_flag = False
            # if ship_flag:
            if sim_time > 3000:
                cmd_list.extend(
                    self._disturb_patrol(self.team_id_dic['RA'],
                                         AIR_DISTURB_POINT4,
                                         DISTURB_PATROL_PARAMS))
                print('Jammer heading south to jam')
                self.disturb_state = RedAgentState.AIR_DISTURB4

        # Head to the designated southern area to jam
        if self.disturb_state == RedAgentState.AIR_DISTURB4:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13 and -50000 < disturb[
                        'X'] < -40000 and -60000 < disturb['Y'] < -50000:
                    cmd_list.extend(
                        self._disturb_patrol(self.team_id_dic['RA'],
                                             AIR_DISTURB_POINT4,
                                             DISTURB_PATROL_PARAMS))
                    self.disturb_state = RedAgentState.AIR_DISTURB5

        # 2-bomber team--HA + 2-fighter escort team--JC
        if self.agent_state == 15:
            cmd_list.extend(
                self._takeoff_areapatrol(2, 15, AREA_HUNT_POINT1,
                                         AREA_PATROL_PARAMS))
            print('Bomber team HA taking off')
            self.agent_state = 9
        for ship in obs_red['qb']:
            if ship['LX'] == 21 and 'HA' in list(self.team_id_dic.keys(
            )) and self.area_hurt_a == RedAgentState.AREA_HUNT11:
                cmd_list.extend(
                    self._targethunt(self.team_id_dic['HA'], ship['ID']))
                self.area_hurt_a = RedAgentState.AREA_HUNT12
                print('HA performing a target strike on the blue ship')
        if self.agent_state == 9:
            if 'HA' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HA']))
                print('Escorting bomber team HA')
                self.agent_state = 10
        # JC moves to northern alert position 1
        if self.a2g_ha == 0 and 'HA' in list(self.team_id_dic.keys()):
            for a2g in obs_red['units']:
                if a2g['TMID'] == self.team_id_dic['HA'] and -75000 < a2g[
                        'X'] < -35000 and 50000 < a2g['Y'] < 80000:
                    for a2a in obs_red['units']:
                        if a2a['TMID'] == self.team_id_dic['JC']:
                            cmd_list.extend(
                                self._areapatrol(a2a['ID'], AIR_PATROL_POINT1,
                                                 AIR_PATROL_PARAMS_0))
                            print('JC moving to northern alert position 1')
                    self.a2g_ha = 1

        # 6-bomber team--HB + 2-fighter escort team--JD
        if self.agent_state == 10:
            cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT0))
            cmd_list.extend(self._takeoff_areahunt(4, AREA_HUNT_POINT0))
            print('Bomber team HB taking off')
            self.agent_state = 11

        if self.agent_state == 11:
            if 'HB' in list(self.team_id_dic.keys()):
                # cmd_list.extend(self._targethunt(self.team_id_dic['HB'], 5011))

                cmd_list.extend(self._A2G_escort(self.team_id_dic['HB']))
                print('Escorting bomber team HB')
                self.agent_state = 12
        # JD moves to northern alert position 2
        if self.a2g_hb == 0 and 'HB' in list(self.team_id_dic.keys()):
            for a2g in obs_red['units']:
                if a2g['TMID'] == self.team_id_dic['HB'] and -75000 < a2g[
                        'X'] < -35000 and 50000 < a2g['Y'] < 80000:
                    for a2a in obs_red['units']:
                        if a2a['TMID'] == self.team_id_dic['JD']:
                            cmd_list.extend(
                                self._areapatrol(a2a['ID'], AIR_PATROL_POINT2,
                                                 AIR_PATROL_PARAMS_0))
                    self.a2g_hb = 1

        # 2-fighter interception team--JE
        if self.agent_state == 8:
            cmd_list.extend(
                self._takeoff_areapatrol(2, 11, AIR_PATROL_POINT3,
                                         AIR_PATROL_PARAMS))
            print('Interceptor team JE taking off')
            self.agent_state = 13
        # Fighters JG, JH --> central interception positions (head south from the center on alert in phase two)
        if self.agent_state == 13:
            cmd_list.extend(
                self._takeoff_areapatrol(2, 11, [-105001, 0, 8000],
                                         AIR_PATROL_PARAMS))
            print('Interceptor team JG taking off')
            self.agent_state = 14
        if self.agent_state == 14:
            cmd_list.extend(
                self._takeoff_areapatrol(2, 11, [-105002, 0, 8000],
                                         AIR_PATROL_PARAMS))
            print('Interceptor team JH taking off')
            # print('Interceptor team JI taking off')
            self.agent_state = 15
        '''Second wave'''

        # If the jammer is still alive, strike the south; otherwise strike the north
        hc = 0
        if self.agent_state == 12 and sim_time > 2000:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13:
                    hc = 1
                    break
            if hc == 1:
                cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT4_0))
                print('Bomber team HC taking off to strike the south')
                self.agent_state = 16
            else:
                cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT0))
                print('Bomber team HC taking off to strike the north')
                self.agent_state = 16
        if self.agent_state == 16:
            if 'HC' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HC']))
                print('Escorting bomber team HC')
                self.agent_state = 17
        '''Third-wave force list (6 bombers, 4 fighters)'''
        # 2-fighter air-combat team--JK
        if self.agent_state == 17:
            cmd_list.extend(
                self._takeoff_areapatrol(2, 11, AIR_PATROL_POINT7,
                                         AIR_PATROL_PARAMS))
            print('Interceptor team JK taking off')
            self.agent_state = 18
        # 6-bomber team--HD + 2-fighter escort team--JL
        if self.agent_state == 18:
            cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT4_0))
            cmd_list.extend(self._takeoff_areahunt(4, AREA_HUNT_POINT4_0))
            print('Bomber team HD taking off')
            self.agent_state = 19

        if self.agent_state == 19:
            if 'HD' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HD']))
                # cmd_list.extend(self._targethunt(self.team_id_dic['HD'], 5010))

                print('Escorting bomber team HD')
                self.agent_state = 20

        # Interception
        blue_lx_list = []
        for blue_unit in obs_red['qb']:
            blue_lx_list.append(blue_unit['LX'])

            # Get blue-side units that are still alive

            if blue_unit['LX'] == 11 or blue_unit['LX'] == 15 or blue_unit[
                    'LX'] == 12:
                if blue_unit['WH'] == 1 and blue_unit[
                        'ID'] not in self.blue_list:

                    distances, fuel, weapons = [], [], []
                    a2a_units = []
                    for a2a in obs_red['units']:
                        if a2a['LX'] == 11 and a2a[
                                'Fuel'] > 3000 and '170' in list(
                                    a2a['WP'].keys()) and int(
                                        a2a['WP']['170']) > 0:

                            distance = math.sqrt(
                                math.pow(a2a['X'] - blue_unit['X'], 2) +
                                math.pow(a2a['Y'] - blue_unit['Y'], 2))
                            distances.append(distance)
                            weapons.append(int(a2a['WP']['170']))
                            fuel.append(a2a["Fuel"])
                            a2a_units.append(a2a)

                    if len(a2a_units) != 0:
                        live = len(distances)
                        dead = config.a2a_LX11 - live
                        distances = np.array(distances + [0] * dead)
                        fuel = np.array(fuel + [0] * dead)
                        weapons = np.array(weapons + [0] * dead)
                        dead_mask = np.array([1] * live + [0] * dead)
                        curr_state = np.hstack(
                            (distances, fuel, weapons, dead_mask))

                        who_index = self.dqn._choose_action(
                            distances, fuel, weapons, dead_mask)
                        print("******************a2a_units length, index: ",
                              len(a2a_units), who_index)
                        a2a = a2a_units[who_index]
                        self.done = True

                        cmd_list.extend(
                            self._airattack(a2a['ID'], blue_unit['ID']))
                        print("***************************执行战斗任务!!!!",
                              a2a['ID'], "  ATTACK   ", blue_unit['ID'])
                        self.blue_list.append(blue_unit['ID'])
                        self.blue_dic[a2a['ID']] = blue_unit['ID']
        # Red shot down the blue unit, or the red interceptor was shot down by blue
        blue = 0
        del_blue = False
        del_blue2 = False
        for blue_target in self.blue_list:
            for blue_unit in obs_red['qb']:
                if blue_target == blue_unit['ID']:
                    blue = 1
                    break
            # Check whether our aircraft is still alive
            for a2a_id in list(self.blue_dic.keys()):
                for a2a in obs_red['units']:
                    if a2a['ID'] == a2a_id:
                        del_blue = True
                        break
                if del_blue is False:
                    self.reward = -5
                    self.blue_dic.pop(a2a_id)
            for a2a_id in list(self.blue_dic.keys()):
                if blue_target == self.blue_dic[a2a_id]:
                    del_blue2 = True
                    break
            if blue == 0 or del_blue is False or del_blue2 is False:
                self.blue_list.remove(blue_target)
                # Re-issue commands based on the red aircraft's current state
                for a2a in obs_red['units']:
                    # Handle aircraft whose state (ST) is 15 or 13 here
                    if a2a['LX'] == 11 and a2a['ID'] in list(
                            self.blue_dic.keys()) and self.blue_dic[
                                a2a['ID']] == blue_target:
                        self.reward = 5
                        if a2a['ST'] == 15 or a2a['ST'] == 13:
                            # Return to base if fuel is below 4000 or ammunition is 0; otherwise patrol the designated area
                            if a2a['Fuel'] < 4000 or int(
                                    a2a['WP']['170']) == 0:
                                cmd_list.extend(self._returntobase(a2a['ID']))
                            else:
                                for Tid in list(self.team_id_dic.keys()):
                                    if a2a['TMID'] == self.team_id_dic[Tid]:
                                        if Tid == 'JA':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AWACS_PATROL_POINT,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JB':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_DISTURB_POINT2,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JC':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT1,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JD':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT2,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JE':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT3,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JF':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT3,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JG':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT5,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JH':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT6,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JI':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT7,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JJ':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT7,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JK':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT7,
                                                    AIR_PATROL_PARAMS_0))
                                        if Tid == 'JL':
                                            cmd_list.extend(
                                                self._areapatrol(
                                                    a2a['ID'],
                                                    AIR_PATROL_POINT4,
                                                    AIR_PATROL_PARAMS_0))
        if self.done:
            self.dqn._store_transition(self.last_state, who_index, self.reward,
                                       curr_state)
            self.dqn._learning()
            self.last_state = curr_state
        self.done = False
        return cmd_list

    def _parse_observation(self, obs_red):
        self._parse_teams(obs_red)

    # Get team IDs
    def _parse_teams(self, red_dict):
        for team in red_dict['teams']:
            if team['Task']:
                self.Task = json.loads(team['Task'])
                # print('self.Task:',self.Task)

            # AWACS team
            if team['LX'] == UnitType.AWACS:
                self.team_id_dic['YA'] = team['TMID']

            # Jammer team
            elif team['LX'] == UnitType.DISTURB:
                if 'fly_num' in list(self.Task.keys()):
                    if self.Task['point_x'] == 45000 and self.Task[
                            'point_y'] == 0:
                        self.team_id_dic['RA'] = team['TMID']

            # Bomber teams
            elif team['LX'] == UnitType.A2G:
                if 'fly_num' in list(self.Task.keys()):
                    # HA
                    if self.Task['fly_num'] == 2 and self.Task[
                            'point_x'] == -55000 and self.Task[
                                'point_y'] == 65000:
                        self.team_id_dic['HA'] = team['TMID']
                    # HB
                    elif self.Task['fly_num'] == 4 and self.Task[
                            'point_x'] == -129533.05624 and self.Task[
                                'point_y'] == 87664.0398:
                        self.team_id_dic['HB'] = team['TMID']
                    # HC
                    elif self.Task['fly_num'] == 2 and self.Task[
                            'point_x'] == -131156.63859 and self.Task[
                                'point_y'] == -87887.86736:
                        self.team_id_dic['HC'] = team['TMID']
                    # HD
                    elif self.Task['fly_num'] == 4 and self.Task[
                            'point_x'] == -131156.63859 and self.Task[
                                'point_y'] == -87887.86736:
                        self.team_id_dic['HD'] = team['TMID']

            # Fighter teams
            elif team['LX'] == UnitType.A2A:
                if self.Task['maintype'] == 'takeoffprotect':
                    # JA
                    if self.Task['cov_id'] == self.team_id_dic['YA']:
                        self.team_id_dic['JA'] = team['TMID']
                    # JB
                    elif self.Task['cov_id'] == self.team_id_dic['RA']:
                        self.team_id_dic['JB'] = team['TMID']
                    # JC
                    elif self.Task['cov_id'] == self.team_id_dic['HA']:
                        self.team_id_dic['JC'] = team['TMID']
                    # JD
                    elif self.Task['cov_id'] == self.team_id_dic['HB']:
                        self.team_id_dic['JD'] = team['TMID']
                    # JF
                    elif self.Task['cov_id'] == self.team_id_dic['HC']:
                        self.team_id_dic['JF'] = team['TMID']
                    # JL
                    elif self.Task['cov_id'] == self.team_id_dic['HD']:
                        self.team_id_dic['JL'] = team['TMID']
                else:
                    # JE
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task[
                                'point_x'] == -105000 and self.Task[
                                    'point_y'] == 0:
                            self.team_id_dic['JE'] = team['TMID']
                    # JG
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task[
                                'point_x'] == -105001 and self.Task[
                                    'point_y'] == 0:
                            self.team_id_dic['JG'] = team['TMID']
                    # JH
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task[
                                'point_x'] == -105002 and self.Task[
                                    'point_y'] == 0:
                            self.team_id_dic['JH'] = team['TMID']
                    # JI
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task[
                                'point_x'] == -45000 and self.Task[
                                    'point_y'] == -55000:
                            self.team_id_dic['JI'] = team['TMID']
                    # JJ
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task[
                                'point_x'] == -45001 and self.Task[
                                    'point_y'] == -55000:
                            self.team_id_dic['JJ'] = team['TMID']
                    # JK
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task[
                                'point_x'] == -95000 and self.Task[
                                    'point_y'] == -85000:
                            self.team_id_dic['JK'] = team['TMID']

    # UAV sortie
    @staticmethod
    def _uav_areapatrol(uav_id, uav_point, uav_params):
        return [EnvCmd.make_uav_areapatrol(uav_id, *uav_point, *uav_params)]

    # AWACS sortie
    @staticmethod
    def _awacs_patrol(self_id, AWACS_PATROL_POINT, AWACS_PATROL_PARAMS):
        return [
            EnvCmd.make_awcs_areapatrol(self_id, *AWACS_PATROL_POINT,
                                        *AWACS_PATROL_PARAMS)
        ]

    # AWACS escort
    def _awacs_escort(self, awacs_team_id):
        return [
            EnvCmd.make_takeoff_protect(RED_AIRPORT_ID, 2, awacs_team_id, 0,
                                        100, 250)
        ]

    # Jammer area jamming
    def _disturb_patrol(self, disturb_team_id, patrol_point, patrol_params):
        return [
            EnvCmd.make_disturb_areapatrol(disturb_team_id, *patrol_point,
                                           *patrol_params)
        ]

    # Jammer route jamming
    def _disturb_linepatrol(self, self_id, point_list):
        return [
            EnvCmd.make_disturb_linepatrol(self_id, 160, 0, 'line', point_list)
        ]

    # Bombers take off for an area strike
    @staticmethod
    def _takeoff_areahunt(num, area_hunt_point):
        return [
            EnvCmd.make_takeoff_areahunt(RED_AIRPORT_ID, num, 270, 80,
                                         *area_hunt_point,
                                         *[270, 1000, 1000, 160])
        ]

    # Jammer escort
    def _disturb_escort(self, disturb_team_id):
        return [
            EnvCmd.make_takeoff_protect(RED_AIRPORT_ID, 2, disturb_team_id, 1,
                                        100, 250)
        ]

    # Bomber escort
    def _A2G_escort(self, a2g_team_id):
        return [
            EnvCmd.make_takeoff_protect(RED_AIRPORT_ID, 2, a2g_team_id, 1, 100,
                                        250)
        ]

    # Take off for area patrol
    @staticmethod
    def _takeoff_areapatrol(num, lx, patrol_point, patrol_params):
        # patrol_params has 5 parameters
        return [
            EnvCmd.make_takeoff_areapatrol(RED_AIRPORT_ID, num, lx,
                                           *patrol_point, *patrol_params)
        ]

    @staticmethod
    def _airattack(unit_id, target_id):
        return [EnvCmd.make_airattack(unit_id, target_id, 0)]

    # Area patrol
    @staticmethod
    def _areapatrol(unit_id, patrol_point, patrol_params):
        # patrol_params has 6 parameters
        return [EnvCmd.make_areapatrol(unit_id, *patrol_point, *patrol_params)]

    # Return to base
    @staticmethod
    def _returntobase(unit_id):
        return [EnvCmd.make_returntobase(unit_id, 30001)]

    # Bomber target strike
    @staticmethod
    def _targethunt(self_id, target_id):
        return [EnvCmd.make_targethunt(self_id, target_id, 270, 80)]

    # Bomber area strike
    @staticmethod
    def _areahunt(self_id, point):
        return [
            EnvCmd.make_areahunt(self_id, 270, 80, *point, *AREA_HUNT_PARAMS)
        ]

    # Frigate area patrol
    def _ship_areapatrol(self, self_id, point):
        return [
            EnvCmd.make_ship_areapatrol(self_id, *point, *SHIP_PATROL_PARAMS_0)
        ]

    # Frigate initial deployment
    def _ship_movedeploy(self, self_id, point):
        return [EnvCmd.make_ship_movedeploy(self_id, *point, 90, 1)]
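
For reference, the interception block in step() above encodes its DQN state as four concatenated blocks of length config.a2a_LX11: distances to the blue target, remaining fuel, remaining type-170 missiles, and an alive mask, each zero-padded for dead fighters. A small helper equivalent to that inline construction (build_state is a sketch, not part of the original class) might look like:

import numpy as np

def build_state(distances, fuel, weapons, n):
    # n = config.a2a_LX11; dead slots are zero-padded exactly as in step() above.
    live = len(distances)
    dead = n - live
    dead_mask = [1] * live + [0] * dead
    return np.hstack((
        np.array(distances + [0] * dead),  # distance of each fighter to the blue unit
        np.array(fuel + [0] * dead),       # remaining fuel
        np.array(weapons + [0] * dead),    # remaining type-170 missiles
        np.array(dead_mask),               # 1 = alive fighter, 0 = padding
    ))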
Example #12
writer = SummaryWriter()
current = datetime.today().strftime('%Y%m%d%H%M%S')

plot_episode_rewards = []  # reward received in the episode (sum of all individual rewards received across the agents)
plot_episode_valid_steps = []  # per-episode count of steps with at least one action request
plot_episode_count_requested_agent = np.asarray(
    [0] * N_AGENTS)  # per-agent record of the number of agents requested
plot_episode_requested_agents = np.asarray([0] * N_AGENTS)
plot_count_per_actions = np.asarray([0] * N_ACTION)
plot_episode_epsilon = []
args = get_common_args()

## switch the policy to a DQN
args = dqn_args(args)
policy = DQN(args)

agents = Agents(args, policy)
env = gym.make('CartPole-v0')
worker = RolloutWorker(env, agents, args)
buffer = ReplayBuffer(args)

plt.figure()
plt.axis([0, args.n_epoch, 0, 100])
win_rates = []
episode_rewards = []
train_steps = 0

save_path = args.result_dir + '/' + current
os.makedirs(save_path, exist_ok=True)
Example #13
class Agent():
    def __init__(self, action_set, hParam):

        h, w = 84, 84
        self.qNetwork = DQN(h, w, len(action_set))
        self.targetNetwork = DQN(h, w, len(action_set))
        self.targetNetwork.load_state_dict(self.qNetwork.state_dict())

        self.optimizer = optim.Adam(self.qNetwork.parameters(),
                                    lr=1e-4)
        self.loss_func = nn.MSELoss()

        self.memory = ReplayMemory(hParam['BUFFER_SIZE']) #

        self.DISCOUNT_FACTOR = hParam['DISCOUNT_FACTOR'] # 0.99 

        self.steps_done = 0
        self.EPS_START = hParam['EPS_START'] # 1.0
        self.EPS_END = hParam['EPS_END']
        self.EPS_ITER = 1000000
        self.MAX_ITER = hParam['MAX_ITER']
        self.eps_threshold = self.EPS_START
        self.BATCH_SIZE = hParam['BATCH_SIZE']

        self.n_actions = len(action_set) # 2

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.qNetwork.to(self.device)
        self.targetNetwork.to(self.device)
        self.qNetwork.train()

    def updateTargetNet(self):
        self.targetNetwork.load_state_dict(self.qNetwork.state_dict())    

    def getAction(self, state):
        state = torch.from_numpy(state).float() / 255.0
        sample = random.random()
        state = state.to(self.device)

        if sample > self.eps_threshold or self.steps_done > 1000000:
            estimate = self.qNetwork(state).max(1)[1].cpu()
            del state

            return estimate.data[0]
        else:
            return random.randint(0, self.n_actions - 1)

    def updateQnet(self):
        if len(self.memory) < self.BATCH_SIZE:
            return

        transitions = self.memory.sample(self.BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)
        done_batch = torch.cat(batch.done).to(self.device)

        with torch.no_grad():
            self.targetNetwork.eval()
            next_state_values = self.targetNetwork(next_state_batch)

        y_batch = torch.cat(tuple(reward if done else reward + self.DISCOUNT_FACTOR * torch.max(value) 
                            for reward, done, value in zip(reward_batch, done_batch, next_state_values)))

        state_action_values = torch.sum(self.qNetwork(state_batch) * action_batch, dim=1)

        loss = self.loss_func(state_action_values, y_batch.detach())

        self.optimizer.zero_grad()
        loss.backward()
        # for param in self.qNetwork.parameters():
        #     param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
        self.updateEPS()
        return loss.data

    def updateEPS(self):
        self.steps_done += 1

        if self.EPS_ITER >= self.steps_done:
            self.eps_threshold = self.EPS_END \
                               + ((self.EPS_START - self.EPS_END) \
                                * (self.EPS_ITER - self.steps_done) / self.EPS_ITER)
        else:
            self.eps_threshold = self.EPS_END

        # print('eps: ',self.eps_threshold)

    def save(self, path='checkpoint.pth.tar'):
        print('save')
        torch.save({
            'state_dict': self.qNetwork.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)

    def load(self, path='checkpoint.pth.tar'):
        print('load:', path)
        checkpoint = torch.load(path)
        self.qNetwork.load_state_dict(checkpoint['state_dict'])
        self.targetNetwork.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
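
updateQnet() relies on a Transition namedtuple and a ReplayMemory class that are not shown in this example. A minimal sketch compatible with the calls above (the field order and the internals of ReplayMemory are assumptions) could be:

import random
from collections import namedtuple, deque

# Assumed field order; updateQnet() only needs these five names to exist on the tuple.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))


class ReplayMemory:
    """Fixed-size cyclic buffer of transitions (a sketch, not the original implementation)."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)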