Example #1
# the snippet begins mid-way through a row_num dispatch; the opening branch was
# presumably `if args.row_num == 7:` (cf. the 7x7 branch in Example #6)
if args.row_num == 7:
    args.column_num = 7
    args.max_time = 75

elif args.row_num == 5:
    args.column_num = 5
    args.max_time = 25

elif args.row_num == 3:
    args.column_num = 3
    args.max_time = 3

################## for initialization ###########################
global log_file

log_file = open(args.save_path + 'log.txt', 'w')

animal_density = generate_map(args)
env = Env(args, animal_density, cell_length=None, canvas=None, gui=False)

patrollers = [Patroller_CNN(args, 'pa_model' + str(i)) for i in range(5)]
poachers = [Poacher(args, 'po_model' + str(i)) for i in range(5)]

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

args.po_location = None
### load the DQN models you have trained
if args.load:
    poachers[0] = Poacher_h(args, animal_density)
    patrollers[0] = Patroller_h(args, animal_density)
Example #2
def main(stdscr):
    # locale.setlocale(locale.LC_ALL, '')
    curses.nonl()

    # constants related to rooms
    room_max_size = 15
    room_min_size = 5
    max_rooms = 15

    # constants related to pad size
    # the pad height/width must be larger than the map's counterparts,
    # because writing to the bottom-right corner of the pad causes an error
    height = 151
    width = 150
    # constants related to map size
    map_height = 150
    map_width = 150
    # get size of the screen for positioning
    # FIXME: as of now, we are going to assume that stdscr size doesn't change
    # stdscr is automatically init by wrapper()
    base_height, base_width = stdscr.getmaxyx()
    # constants related to view size
    # TODO: change view size in relation to screen size
    view_width = 100
    view_height = 24

    # default setups
    init_colors()
    curses.curs_set(0)  # hide cursor
    # win has to be a pad, so that scrolling is easily supported
    win = curses.newpad(height, width)
    win.keypad(True)
    win.bkgd(' ')
    # msgwin
    msg_win = curses.newpad(10, 100)
    msgbox = MsgBox(msg_win, view_width, view_height, base_width, base_height)

    # bars
    bar_width = 33
    bar_height = 1
    bar_win = curses.newwin(bar_height, bar_width, 1, 1)
    # bar_win.border()

    combat_module = Combat(hp=30, defense=2, power=5)
    inventory = Inventory(26)
    player = Player(combat_module, inventory)
    entities = [player]
    game_map = maps.GameMap(map_width, map_height)
    maps.generate_map(game_map, max_rooms, room_min_size, room_max_size,
                      player, entities)

    game_state = GameStates.PLAYERS_TURN
    previous_game_state = game_state
    # initial compute of fov
    game_map.compute_fov(player)
    inventory_menu = Menu(win, "INVENTORY", 30, base_width, base_height)
    while True:
        rendering.render_all(win, entities, game_map, view_width, view_height,
                             player, base_width, base_height, msgbox, bar_win,
                             inventory_menu, game_state)
        action = input_handler.handle_input(win, game_state)

        mv = action.get('move')
        pickup = action.get('pickup')
        show_inventory = action.get('show_inventory')
        hide_inventory = action.get('hide_inventory')
        item_at_cursor = action.get('item_at_cursor')
        inventory_index = action.get('inventory_index')
        move_cursor = action.get('move_cursor')
        move_page = action.get('move_page')
        exit = action.get('exit')

        inventory_shown = False

        player_turn_results = []

        if mv:
            dx, dy = mv
            dest_x = player.x + dx
            dest_y = player.y + dy

            if game_map.walkable[dest_x, dest_y]:
                target = blocking_entity_at_position(entities, dest_x, dest_y)
                if target:
                    atk_results = player.combat.attack(target)
                    player_turn_results.extend(atk_results)
                else:
                    move_results = {'move': mv}
                    player_turn_results.append(move_results)
        elif pickup:
            for e in entities:
                if e.item and e.x == player.x and e.y == player.y:
                    pickup_results = player.inventory.add_item(e)
                    player_turn_results.extend(pickup_results)
                    # only pick up one item per turn
                    break
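            # this `else` belongs to the `for` loop: it runs only when the loop
            # finishes without a `break`, i.e. when no item lies under the player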
            else:
                msgbox.add("no_item")
        # Toggle Inventory screen
        elif show_inventory:
            msgbox.add("open inven")
            previous_game_state = game_state
            game_state = GameStates.SHOW_INVENTORY
            # FIXME: cursor, page should have a value limit
            # and it probably should be handled by menus, not here
        elif exit:
            # quit game
            # break
            pass
        if game_state == GameStates.SHOW_INVENTORY:
            if move_cursor:
                inventory_menu.next_item(move_cursor)
            elif move_page:
                inventory_menu.next_page(move_page)
            elif hide_inventory:
                game_state = previous_game_state
            elif item_at_cursor:
                item = inventory_menu.item_at_cursor()
                if item:  # TEMP FIX: check validity of choice at menu?
                    use_results = player.inventory.use_item(item)
                    player_turn_results.extend(use_results)
            elif inventory_index is not None:
                # unlike the other inputs, inventory_index can legitimately be 0,
                # so compare it with None; a plain truthiness check would silently
                # ignore index 0
                item = inventory_menu.item_at(inventory_index)
                if item:
                    use_results = player.inventory.use_item(item)
                    player_turn_results.extend(use_results)

        # evaluate results of player turn
        for result in player_turn_results:
            movement = result.get('move')
            dead_entity = result.get('dead')
            item_added = result.get('item_added')
            msg = result.get('msg')
            item_used = result.get('item_used')
            if movement:
                dx, dy = movement
                player.move(dx, dy)
                game_map.compute_fov(player)
            if item_added:
                entities.remove(item_added)
            if msg:
                msgbox.add(msg)
            if dead_entity == player:
                game_state = GameStates.PLAYER_DEAD
            if item_used:
                inventory_shown = True

            # toggle state only when something is done in PLAYERS_TURN
            game_state = GameStates.ENEMY_TURN

        if game_state == GameStates.ENEMY_TURN:
            # move those with ai modules
            enemies = (e for e in entities if e.ai)
            for e in enemies:

                e_turn_results = e.ai.take_turn(player, game_map, entities)

                # still a bit WET!
                for result in e_turn_results:
                    msg = result.get('msg')
                    dead_entity = result.get('dead')
                    if msg:
                        msgbox.add(msg)
                    if dead_entity == player:
                        game_state = GameStates.PLAYER_DEAD

        # check whether to return to beginning of loop
        if game_state == GameStates.PLAYER_DEAD:
            break
        elif game_state == GameStates.SHOW_INVENTORY:
            pass
        # if an item was used on the inventory screen, keep the inventory open
        elif inventory_shown:
            game_state = GameStates.SHOW_INVENTORY
        else:
            game_state = GameStates.PLAYERS_TURN
Example #3
def main(wizard_args=None):

    argparser = argparse.ArgumentParser()
    ########################################################################################
    ### Presets
    argparser.add_argument('--exac_loc_always_no_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_always_with_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_no_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_with_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_50_no_noise', type=bool, default=False)

    argparser.add_argument('--exac_loc_always_no_noise_no_vis',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_always_with_noise_no_vis',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_no_noise_no_vis',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_with_noise_no_vis',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_50_no_noise_no_vis',
                           type=bool,
                           default=False)
    ### Changes by us
    argparser.add_argument('--footsteps', type=bool, default=False)
    argparser.add_argument('--po_bleeb', type=bool, default=False)
    argparser.add_argument('--filter_bleeb', type=bool, default=False)
    argparser.add_argument('--see_surrounding', type=bool, default=False)

    argparser.add_argument('--tourist_noise', type=float, default=0.01)
    argparser.add_argument('--po_scan_rate', type=float, default=0.10)

    argparser.add_argument('--extra_sensor_pa', type=bool, default=False)
    argparser.add_argument('--extra_sensor_po', type=bool, default=False)
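
    # note: argparse applies bool() to the raw command-line string, and since
    # bool('False') is True, any non-empty value turns these type=bool flags on;
    # if explicit on/off parsing is wanted, action='store_true' is the usual
    # alternative, e.g.:
    #     argparser.add_argument('--footsteps', action='store_true')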

    ### Environment
    argparser.add_argument('--row_num', type=int, default=3)
    argparser.add_argument('--column_num', type=int, default=3)
    argparser.add_argument('--ani_den_seed', type=int, default=66)

    ### Patroller
    argparser.add_argument('--pa_state_size', type=int, default=-1)  # 21
    argparser.add_argument('--pa_num_actions', type=int, default=5)

    ### Poacher CNN
    argparser.add_argument('--snare_num', type=int, default=3)
    argparser.add_argument('--po_state_size', type=int,
                           default=-1)  # add self footprint to poacher # 22
    argparser.add_argument('--po_num_actions', type=int, default=5)

    ### Poacher Rule Base, parameters set following advice from domain experts
    argparser.add_argument('--po_act_den_w', type=float, default=3.)
    argparser.add_argument('--po_act_enter_w', type=float, default=0.3)
    argparser.add_argument('--po_act_leave_w', type=float, default=-1.0)
    argparser.add_argument('--po_act_temp', type=float, default=5.0)
    argparser.add_argument('--po_home_dir_w', type=float, default=3.0)

    ### Training
    argparser.add_argument('--Delta',
                           type=float,
                           default=0.0,
                           help='the exploration rate in the meta-strategy')
    argparser.add_argument('--naive',
                           type=bool,
                           default=False,
                           help='whether to use naive PSRO')
    argparser.add_argument(
        '--advanced_training',
        type=bool,
        default=True,
        help='whether to use dueling double DQN with gradient clipping')
    argparser.add_argument('--map_type', type=str, default='random')
    argparser.add_argument(
        '--po_location',
        type=int,
        default=None,
        help='0, 1, 2, 3 for local modes; None for global mode')
    argparser.add_argument('--save_path',
                           type=str,
                           default='./Results_33_random/')

    argparser.add_argument('--pa_episode_num', type=int, default=300000)
    argparser.add_argument('--po_episode_num', type=int, default=300000)
    argparser.add_argument('--epi_num_incr', type=int,
                           default=0)  # no usage now
    argparser.add_argument('--final_incr_iter', type=int,
                           default=10)  # no usage now
    argparser.add_argument('--pa_replay_buffer_size', type=int, default=200000)
    argparser.add_argument('--po_replay_buffer_size', type=int, default=100000)
    argparser.add_argument('--test_episode_num', type=int, default=5000)
    argparser.add_argument('--iter_num',
                           type=int,
                           default=20,
                           help='DO iteration num')
    argparser.add_argument('--load_path', type=str, default='./Results5x5/')
    argparser.add_argument('--load_num', type=int, default=0)
    argparser.add_argument('--pa_initial_lr', type=float, default=1e-4)
    argparser.add_argument('--po_initial_lr', type=float, default=5e-5)

    argparser.add_argument('--br_po_DQN_episode_num', type=int, default=500)
    argparser.add_argument('--print_every', type=int, default=50)
    argparser.add_argument('--zero_sum',
                           type=int,
                           default=1,
                           help='whether to set the game zero-sum')
    argparser.add_argument('--batch_size', type=int, default=32)
    argparser.add_argument('--target_update_every', type=int, default=2000)
    argparser.add_argument('--reward_gamma', type=float, default=0.95)
    argparser.add_argument('--save_every_episode', type=int,
                           default=200)  #10000)
    argparser.add_argument('--test_every_episode', type=int, default=10000)
    argparser.add_argument('--gui_every_episode', type=int, default=500)
    argparser.add_argument('--gui_test_num', type=int, default=20)
    argparser.add_argument('--gui', type=int, default=0)
    argparser.add_argument('--mix_every_episode', type=int, default=250)
    argparser.add_argument(
        '--epsilon_decrease',
        type=float,
        default=0.05,
        help='decrease of the epsilon exploration rate in DQN')
    argparser.add_argument('--PER',
                           type=bool,
                           default=False,
                           help='whether to use prioritized experience replay')
    argparser.add_argument('--reward_shaping',
                           type=bool,
                           default=False,
                           help='whether to use reward shaping in training')

    argparser.add_argument('--max_time', type=int, default=100)
    #########################################################################################
    args = argparser.parse_args()

    if not args.po_bleeb and args.filter_bleeb:
        raise ValueError(
            'filter_bleeb cannot be True while po_bleeb is False')

    #### PRESETS ####
    if args.exac_loc_always_no_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = True
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_always_with_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = False

        args.see_surrounding = True
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.blur_loc_always_no_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = True

        args.see_surrounding = True
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.blur_loc_always_with_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = True

        args.see_surrounding = True
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_50_no_noise:
        args.po_bleeb = True
        args.po_scan_rate = 0.5
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = True
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_always_no_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_always_with_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = False

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.blur_loc_always_no_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = True

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.blur_loc_always_with_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = True

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_50_no_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 0.5
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    if wizard_args:
        for k, v in wizard_args.items():
            setattr(args, k, v)

    if args.po_state_size == -1:
        args.po_state_size = 14 + (8 * args.footsteps) + (
            1 * args.see_surrounding) + (1 * args.extra_sensor_po)

    if args.pa_state_size == -1:
        args.pa_state_size = 12 + (8 * args.footsteps) + (
            1 * args.po_bleeb) + (1 * args.see_surrounding) + (
                1 * args.extra_sensor_pa)
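    # state sizes scale with the enabled observation channels: footsteps adds 8
    # features, and each boolean sensor flag adds one feature on top of the base
    # 14 (poacher) / 12 (patroller) features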

    print("ARGS:", args)

    ### END PRESETS ####

    if args.row_num == 10:
        args.column_num = 10
        args.max_time = 75
        args.pa_initial_lr = 1e-4
        args.po_initial_lr = 5e-5
        args.pa_replay_buffer_size = 200000
        args.po_replay_buffer_size = 100000
        if args.po_location is not None:
            args.pa_episode_num = 2000
            args.po_episode_num = 2000
    # test
    # if args.row_num == 7:
    #     args.column_num = 7
    #     args.max_time = 100
    #     args.pa_initial_lr = 1e-4
    #     args.po_initial_lr = 5e-5
    #     args.pa_replay_buffer_size = 200000
    #     args.po_replay_buffer_size = 100000
    #     if args.po_location is not None:
    #         args.pa_episode_num = 10000
    #         args.po_episode_num = 10000
    #
    # elif args.row_num == 5:
    #     args.column_num = 5
    #     args.max_time = 25
    #     args.pa_episode_num = 300000
    #     args.po_episode_num = 300000
    #     args.pa_initial_lr = 1e-4
    #     args.po_initial_lr = 5e-5
    #     args.pa_replay_buffer_size = 50000
    #     args.po_replay_buffer_size = 40000
    #     if args.po_location is not None:
    #         args.pa_episode_num = 200000
    #
    #
    # elif args.row_num == 3:
    #     args.column_num = 3
    #     args.max_time = 4
    #     args.snare_num = 3
    #     args.pa_episode_num = 500 #100000
    #     args.po_episode_num = 500 #100000
    #     args.pa_initial_lr = 5e-5
    #     args.po_initial_lr = 5e-5
    #     args.pa_replay_buffer_size = 200 #10000
    #     args.po_replay_buffer_size = 200 #8000
    #     if args.po_location is not None:
    #         args.pa_episode_num = 200 # 80000
    #         args.po_episode_num = 200 # 80000

    if args.naive:
        args.Delta = 0.0
        args.po_location = None

    # args.save_path = './' + str(args.pa_episode_num) + "_" + "filterbleeb:" + str(args.filter_bleeb) + "_touristnoise:" + \
    #                  str(args.tourist_noise) + "_footsteps:" + str(args.footsteps) + "_seesurrounding:" + str(args.see_surrounding) + \
    #                  "_poscanrate:" + str(args.po_scan_rate) + "_" + str(args.row_num) + "x" + str(args.column_num)

    if args.save_path and (not os.path.exists(args.save_path)):
        os.makedirs(args.save_path)

    with open(args.save_path + '/train_args.json', 'w') as f:
        f.write(json.dumps(vars(args)))

    paralog = open(args.save_path + '/paralog.txt', 'w')
    paralog.write('row_num {0} \n'.format(args.row_num))
    paralog.write('snare_num {0} \n'.format(args.snare_num))
    paralog.write('max_time {0} \n'.format(args.max_time))
    paralog.write('animal density seed {0} \n'.format(args.ani_den_seed))
    paralog.write('pa_initial_episode_num {0} \n'.format(args.pa_episode_num))
    paralog.write('po_initial_episode_num {0} \n'.format(args.po_episode_num))
    paralog.write('epi_num_incr {0} \n'.format(args.epi_num_incr))
    paralog.write('final_incr_iter {0} \n'.format(args.final_incr_iter))
    paralog.write('pa_replay_buffer_size {0} \n'.format(
        args.pa_replay_buffer_size))
    paralog.write('po_replay_buffer_size {0} \n'.format(
        args.po_replay_buffer_size))
    paralog.write('pa_initial_lr {0} \n'.format(args.pa_initial_lr))
    paralog.write('po_initial_lr {0} \n'.format(args.po_initial_lr))
    paralog.write('test_episode_num {0} \n'.format(args.test_episode_num))
    paralog.write('Delta {0} \n'.format(args.Delta))
    paralog.write('po_location {0} \n'.format(str(args.po_location)))
    paralog.write('map_type {0} \n'.format(str(args.map_type)))

    paralog.write('filter_bleeb {0} \n'.format(str(args.filter_bleeb)))
    paralog.write('po_bleeb {0} \n'.format(str(args.po_bleeb)))
    paralog.write('naive {0} \n'.format(str(args.naive)))
    paralog.flush()
    paralog.close()

    ################## for initialization ###########################
    global log_file

    log_file = open(args.save_path + '/log.txt', 'w')

    animal_density = generate_map(args)
    env_pa = Env(args,
                 animal_density,
                 cell_length=None,
                 canvas=None,
                 gui=False)
    env_po = Env(args,
                 animal_density,
                 cell_length=None,
                 canvas=None,
                 gui=False)

    patrollers = [
        Patroller_CNN(args, 'pa_model' + str(i))
        for i in range(args.iter_num + 1)
    ]
    poachers = [
        Poacher(args, 'po_model' + str(i)) for i in range(args.iter_num + 1)
    ]
    pa_type = ['DQN']
    po_type = ['DQN']

    ### initialize poachers needed for training a separate best-response poacher DQN
    br_poacher = Poacher(args, 'br_poacher')
    br_target_poacher = Poacher(args, 'br_target_poacher')
    br_good_poacher = Poacher(args, 'br_good_poacher')
    br_utility = np.zeros(2)

    if not args.naive:
        patrollers[0] = RandomSweepingPatroller(args, mode=args.po_location)
        pa_type[0] = 'RS'
    if not args.naive:
        poachers[0] = Poacher_h(args, animal_density)
        po_type[0] = 'PARAM'

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    ### copy ops needed for training a separate best-response poacher DQN
    br_po_copy_ops = tf_copy(br_target_poacher, br_poacher, sess)
    br_po_good_copy_ops = tf_copy(br_good_poacher, br_poacher, sess)

    pa_payoff = np.zeros((1, 1))
    po_payoff = np.zeros((1, 1))
    length = np.zeros((1, 1))

    pa_payoff[0, 0], po_payoff[0, 0], _ = simulate_payoff(patrollers,
                                                          poachers,
                                                          0,
                                                          0,
                                                          env_pa,
                                                          sess,
                                                          args,
                                                          pa_type=pa_type[0],
                                                          po_type=po_type[0])

    pa_strategy, po_strategy = np.array([1]), np.array([1])
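    # both players start the double oracle loop with a single pure strategy played
    # with probability 1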

    np.save(file=args.save_path + 'pa_strategy_iter_0', arr=pa_strategy)
    np.save(file=args.save_path + 'po_strategy_iter_0', arr=po_strategy)

    np.save(file=args.save_path + 'pa_payoff_iter_0', arr=pa_payoff)
    np.save(file=args.save_path + 'po_payoff_iter_0', arr=po_payoff)

    log_file.write('pa_payoff:\n' + str(pa_payoff) + '\n')
    log_file.write('po_payoff:\n' + str(po_payoff) + '\n')

    log_file.write('pa_strat:\n' + str(pa_strategy) + '\n')
    log_file.write('po_strat:\n' + str(po_strategy) + '\n')

    ############## starting DO ####################
    iteration = 1
    pa_pointer, po_pointer = 1, 1  # pointers counting the current number of strategies for pa and po

    while (1):
        time_begin = time.time()

        pa_payoff, po_payoff, length = extend_payoff(pa_payoff, po_payoff,
                                                     length, pa_pointer + 1,
                                                     po_pointer + 1)
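        # the payoff matrices are enlarged by one row and one column to hold the
        # best-response strategies trained in this iteration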
        po_type.append('DQN')
        pa_type.append('DQN')

        log_file.flush()

        print('\n' + 'NEW_ITERATION: ' + str(iteration) + '\n')
        log_file.write('\n' + 'NEW_ITERATION: ' + str(iteration) + '\n')

        ### compute the NE utility for both sides
        po_ne_utility = 0
        pa_ne_utility = 0
        for pa_strat in range(pa_pointer):
            for po_strat in range(po_pointer):
                po_ne_utility += pa_strategy[pa_strat] * po_strategy[
                    po_strat] * po_payoff[pa_strat, po_strat]
                pa_ne_utility += pa_strategy[pa_strat] * po_strategy[
                    po_strat] * pa_payoff[pa_strat, po_strat]

        log_file.write('last_pa_ne_utility:' + str(pa_ne_utility) + '\n')
        log_file.write('last_po_ne_utility:' + str(po_ne_utility) + '\n')
        pre_pa_strategy = pa_strategy
        pre_po_strategy = po_strategy

        ### compute the best response poacher utility
        ### 1. train a best response poacher DQN against the current pa strategy
        calc_po_best_response(br_poacher,
                              br_target_poacher,
                              br_po_copy_ops,
                              br_po_good_copy_ops,
                              patrollers,
                              pa_strategy,
                              pa_type,
                              iteration,
                              sess,
                              env_pa,
                              args,
                              br_utility,
                              0,
                              train_episode_num=args.br_po_DQN_episode_num)
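        # calc_po_best_response is expected to store the trained DQN's utility in
        # br_utility; index 1 (the poacher side) is read below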
        br_DQN_utility = br_utility[1]

        ### 2. test against the heuristic poacher stored in poachers[0]
        br_heuristic_utility = 0.
        for i in range(pa_pointer):
            _, po_utility, _ = simulate_payoff(patrollers,
                                               poachers,
                                               i,
                                               0,
                                               env_pa,
                                               sess,
                                               args,
                                               pa_type=pa_type[i],
                                               po_type=po_type[0])
            br_heuristic_utility += po_utility * pa_strategy[i]

        ### choose the better one
        better = 'DQN' if br_DQN_utility >= br_heuristic_utility else 'heuristic'
        br_poacher_utility = max(br_DQN_utility, br_heuristic_utility)
        log_file.write(
            'Iteration {0} poacher best response utility {1} poacher best response type {2} \n'
            .format(iteration, br_poacher_utility, better))
        print(
            'Iteration {0} poacher best response utility {1} poacher best response type {2}'
            .format(iteration, br_poacher_utility, better))

        ### train the best response agent
        ### using threading to accelerate the training
        good_patrollers = []
        good_poachers = []
        final_utility = [0.0, 0.0]
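        # copy ops below: each learner gets a 'target' network for DQN
        # bootstrapping, a 'good' network that appears to snapshot the best
        # weights found during training, and inverse ops that copy the snapshot
        # back into the strategy pool entry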
        target_patroller = Patroller_CNN(args,
                                         'target_patroller' + str(iteration))
        good_patroller = Patroller_CNN(args, 'good_patroller' + str(iteration))
        pa_copy_ops = tf_copy(target_patroller, patrollers[pa_pointer], sess)
        pa_good_copy_ops = tf_copy(good_patroller, patrollers[pa_pointer],
                                   sess)
        pa_inverse_ops = tf_copy(patrollers[pa_pointer], good_patroller, sess)

        target_poacher = Poacher(args, 'target_poacher' + str(iteration))
        good_poacher = Poacher(args, 'good_poacher' + str(iteration))
        po_copy_ops = tf_copy(target_poacher, poachers[po_pointer], sess)
        po_good_copy_ops = tf_copy(good_poacher, poachers[po_pointer], sess)
        po_inverse_ops = tf_copy(poachers[po_pointer], good_poacher, sess)

        funcs = [calc_pa_best_response, calc_po_best_response]
        params = [[
            patrollers[pa_pointer], target_patroller, pa_copy_ops,
            pa_good_copy_ops, poachers, po_strategy, po_type, iteration, sess,
            env_pa, args, final_utility, 0
        ],
                  [
                      poachers[po_pointer], target_poacher, po_copy_ops,
                      po_good_copy_ops, patrollers, pa_strategy, pa_type,
                      iteration, sess, env_po, args, final_utility, 0
                  ]]

        ### if the maximum iteration number is reached
        if args.iter_num == iteration:
            log_file.write(
                '\n DO reaches terminating iteration {0}'.format(iteration) +
                '\n')
            log_file.write('Final Pa-payoff: \n' + str(pa_payoff) + '\n')
            log_file.write('Final Po-payoff: \n' + str(po_payoff) + '\n')
            log_file.write('Final pa_strat:\n' + str(pa_strategy) + '\n')
            log_file.write('Final po_strat:\n' + str(po_strategy) + '\n')
            log_file.write('Final pa_ne_utility:' + str(pa_ne_utility) + '\n')
            log_file.write('Final po_ne_utility:' + str(po_ne_utility) + '\n')
            log_file.flush()

            threads = []
            for i in range(2):
                process = Thread(target=funcs[i], args=params[i])
                process.start()
                threads.append(process)
            ### We now pause execution on the main thread by 'joining' all of our started threads.
            for process in threads:
                process.join()

            pa_exploit = final_utility[0] - pa_ne_utility
            po_exploit = final_utility[1] - po_ne_utility
            log_file.write('Final pa_best_response_utility:' +
                           str(final_utility[0]) + '\n')
            log_file.write('Final po_best_response_utility:' +
                           str(final_utility[1]) + '\n')
            log_file.write('Final pa exploitability:' + str(pa_exploit) + '\n')
            log_file.write('Final po exploitability:' + str(po_exploit) + '\n')
            break

        ### not the final iteration
        threads = []

        for i in range(2):
            process = Thread(target=funcs[i], args=params[i])
            process.start()
            threads.append(process)
        for process in threads:
            process.join()

        # calc_pa_best_response(patrollers[pa_pointer], target_patroller, pa_copy_ops, pa_good_copy_ops, poachers,
        #         po_strategy, iteration, sess, env_pa, args, final_utility,0)

        sess.run(pa_inverse_ops)
        sess.run(po_inverse_ops)

        for pa_strat in range(pa_pointer):
            pa_payoff[pa_strat, po_pointer ],po_payoff[pa_strat, po_pointer], _  = \
                simulate_payoff(patrollers, poachers, pa_strat, po_pointer, env_pa, sess, args,
                    pa_type=pa_type[pa_strat], po_type=po_type[po_pointer])

        for po_strat in range(po_pointer):
            pa_payoff[pa_pointer, po_strat],po_payoff[pa_pointer, po_strat],_  = \
                simulate_payoff(patrollers, poachers, pa_pointer, po_strat, env_pa, sess, args,
                pa_type=pa_type[pa_pointer], po_type = po_type[po_strat])

        pa_payoff[pa_pointer, po_pointer],po_payoff[pa_pointer, po_pointer],_  = \
            simulate_payoff(patrollers, poachers, pa_pointer, po_pointer, env_pa, sess, args,
            pa_type=pa_type[pa_pointer], po_type = po_type[po_pointer])

        pa_strategy, po_strategy = calc_NE_zero(pa_payoff, po_payoff,
                                                args.Delta)
        # pa_strategy, po_strategy = np.ones(iteration + 1) / (iteration + 1), np.ones(iteration + 1) / (iteration + 1)

        params[0][5] = po_strategy
        params[1][5] = pa_strategy

        po_best_response = final_utility[1]
        pa_best_response = final_utility[0]
        # for pa_strat in range(pa_pointer):
        #     po_best_response += pre_pa_strategy[pa_strat] * po_payoff[pa_strat, po_pointer]
        # for po_strat in range(po_pointer):
        #     pa_best_response += pre_po_strategy[po_strat] * pa_payoff[pa_pointer, po_strat]

        # eps_po.append(po_best_response - po_ne_utility)
        # eps_pa.append(pa_best_response - pa_ne_utility)

        log_file.write('In DO pa_best_utility:' + str(pa_best_response) + '\n')
        log_file.write('In DO po_best_utility:' + str(po_best_response) + '\n')
        # log_file.write('eps_pa: ' + str(eps_pa) + '\n')
        # log_file.write('eps_po: ' + str(eps_po) + '\n')

        ######### save models for this iteration #############
        save_name = args.save_path + 'iteration_' + str(
            iteration) + '_pa_model.ckpt'
        patrollers[pa_pointer].save(sess=sess, filename=save_name)
        save_name = args.save_path + 'iteration_' + str(
            iteration) + '_po_model.ckpt'
        poachers[po_pointer].save(sess=sess, filename=save_name)

        ### save payoff matrix and ne strategies
        np.save(file=args.save_path + 'pa_payoff_iter_' + str(iteration),
                arr=pa_payoff)
        np.save(file=args.save_path + 'po_payoff_iter_' + str(iteration),
                arr=po_payoff)
        np.save(file=args.save_path + 'pa_strategy_iter_' + str(iteration),
                arr=pa_strategy)
        np.save(file=args.save_path + 'po_strategy_iter_' + str(iteration),
                arr=po_strategy)

        log_file.write('pa_payoff:\n' + str(pa_payoff) + '\n')
        log_file.write('po_payoff:\n' + str(po_payoff) + '\n')
        log_file.write('pa_strategy:\n' + str(pa_strategy) + '\n')
        log_file.write('po_strategy:\n' + str(po_strategy) + '\n')

        iteration += 1
        pa_pointer += 1
        po_pointer += 1

        time_end = time.time()

        log_file.write('Using time: \n' + str(time_end - time_begin) + '\n')
        log_file.flush()

    log_file.close()
Example #4
def main(wizard_args=None):

    argparser = argparse.ArgumentParser()
    ########################################################################################
    ### Presets
    argparser.add_argument('--exac_loc_always_no_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_always_with_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_no_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_with_noise',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_50_no_noise', type=bool, default=False)

    argparser.add_argument('--exac_loc_always_no_noise_no_vis',
                           type=bool,
                           default=False)  #no files for this
    argparser.add_argument('--exac_loc_always_with_noise_no_vis',
                           type=bool,
                           default=False)  #no files for this
    argparser.add_argument('--blur_loc_always_no_noise_no_vis',
                           type=bool,
                           default=False)
    argparser.add_argument('--blur_loc_always_with_noise_no_vis',
                           type=bool,
                           default=False)
    argparser.add_argument('--exac_loc_50_no_noise_no_vis',
                           type=bool,
                           default=False)  #no files for this

    ### Changes by us
    argparser.add_argument('--footsteps', type=bool, default=False)
    argparser.add_argument('--po_bleeb', type=bool, default=False)
    argparser.add_argument('--filter_bleeb', type=bool, default=False)
    argparser.add_argument('--see_surrounding', type=bool, default=False)

    argparser.add_argument('--tourist_noise', type=float, default=0.01)
    argparser.add_argument('--po_scan_rate', type=float, default=0.10)

    argparser.add_argument('--extra_sensor_pa', type=bool, default=False)
    argparser.add_argument('--extra_sensor_po', type=bool, default=False)

    ### Test parameters
    argparser.add_argument('--pa_load_path', type=str, default='./Results5x5/')
    argparser.add_argument('--po_load_path', type=str, default='./Results5x5/')
    argparser.add_argument('--load', type=bool, default=False)

    ### Environment
    argparser.add_argument('--row_num', type=int, default=3)
    argparser.add_argument('--column_num', type=int, default=3)
    argparser.add_argument('--ani_den_seed', type=int, default=66)
    argparser.add_argument('--max_time', type=int, default=50)

    ### Patroller
    argparser.add_argument('--pa_state_size', type=int, default=-1)
    argparser.add_argument('--pa_num_actions', type=int, default=5)

    ### Poacher CNN
    argparser.add_argument('--snare_num', type=int, default=1)
    argparser.add_argument('--po_state_size', type=int,
                           default=-1)  # yf: add self footprint to poacher
    argparser.add_argument('--po_num_actions', type=int, default=5)

    ### Poacher Rule Base
    argparser.add_argument('--po_act_den_w', type=float, default=3.)
    argparser.add_argument('--po_act_enter_w', type=float, default=0.3)
    argparser.add_argument('--po_act_leave_w', type=float, default=-1.0)
    argparser.add_argument('--po_act_temp', type=float, default=5.0)
    argparser.add_argument('--po_home_dir_w', type=float, default=3.0)

    ### Training
    argparser.add_argument('--map_type', type=str, default='random')
    argparser.add_argument('--advanced_training', type=bool, default=True)
    argparser.add_argument('--save_path',
                           type=str,
                           default='./Results33Parandom/')

    argparser.add_argument('--naive', type=bool, default=False)
    argparser.add_argument('--pa_episode_num', type=int, default=300000)
    argparser.add_argument('--po_episode_num', type=int, default=300000)
    argparser.add_argument('--pa_initial_lr', type=float, default=1e-4)
    argparser.add_argument('--po_initial_lr', type=float, default=5e-5)
    argparser.add_argument('--epi_num_incr', type=int, default=0)
    argparser.add_argument('--final_incr_iter', type=int, default=10)
    argparser.add_argument('--pa_replay_buffer_size', type=int, default=200000)
    argparser.add_argument('--po_replay_buffer_size', type=int, default=100000)
    argparser.add_argument('--test_episode_num', type=int, default=20000)
    argparser.add_argument('--iter_num', type=int, default=10)
    argparser.add_argument('--po_location', type=int, default=None)
    argparser.add_argument('--Delta', type=float, default=0.0)

    argparser.add_argument('--print_every', type=int, default=50)
    argparser.add_argument('--zero_sum', type=int, default=1)
    argparser.add_argument('--batch_size', type=int, default=32)
    argparser.add_argument('--target_update_every', type=int, default=2000)
    argparser.add_argument('--reward_gamma', type=float, default=0.95)
    argparser.add_argument('--save_every_episode', type=int, default=5000)
    argparser.add_argument('--test_every_episode', type=int, default=2000)
    argparser.add_argument('--gui_every_episode', type=int, default=500)
    argparser.add_argument('--gui_test_num', type=int, default=20)
    argparser.add_argument('--gui', type=int, default=0)
    argparser.add_argument('--mix_every_episode', type=int,
                           default=250)  # new added
    argparser.add_argument('--epsilon_decrease', type=float,
                           default=0.05)  # new added
    argparser.add_argument('--reward_shaping', type=bool, default=False)
    argparser.add_argument('--PER', type=bool, default=False)

    #########################################################################################
    args = argparser.parse_args()

    if not args.po_bleeb and args.filter_bleeb:
        raise ValueError(
            'filter_bleeb cannot be True while po_bleeb is False')

    #### PRESETS ####
    # print("HUH", args)
    # print("WIZARD:", wizard_args)
    # print("JAAA", args.exac_loc_always_with_noise)
    #

    # if args.row_num == 7:
    #     args.column_num = 7
    #     args.max_time = 75

    if wizard_args:
        for k, v in wizard_args.items():
            setattr(args, k, v)

    if args.exac_loc_always_no_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = True
        args.footsteps = False

    elif args.exac_loc_always_with_noise:
        # print("JA DIT TRIGGERED")
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = False

        args.column_num = 7
        args.row_num = 7
        args.map_type = "poacher"
        #log_file = open('./Results_33_random/log.txt', 'w')
        args.see_surrounding = True
        args.footsteps = False

    elif args.blur_loc_always_no_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = True

        args.column_num = 7
        args.row_num = 7
        args.map_type = "poacher"
        args.see_surrounding = True
        args.footsteps = False

    elif args.blur_loc_always_with_noise:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = False

        args.column_num = 7
        args.row_num = 7
        args.map_type = "poacher"
        args.see_surrounding = True
        args.footsteps = False

    elif args.exac_loc_50_no_noise:
        args.po_bleeb = True
        args.po_scan_rate = 0.5
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = True
        args.footsteps = False

    elif args.exac_loc_always_no_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_always_with_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0.05
        args.filter_bleeb = False

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.blur_loc_always_no_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = True

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.blur_loc_always_with_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 1
        args.tourist_noise = 0
        args.filter_bleeb = True

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    elif args.exac_loc_50_no_noise_no_vis:
        args.po_bleeb = True
        args.po_scan_rate = 0.5
        args.tourist_noise = 0
        args.filter_bleeb = False

        args.see_surrounding = False
        args.footsteps = False

        args.map_type = 'poacher'
        args.naive = True
        args.row_num = 7
        args.column_num = 7

    if args.po_state_size == -1:
        args.po_state_size = 14 + (8 * args.footsteps) + (
            1 * args.see_surrounding) + (1 * args.extra_sensor_po)

    if args.pa_state_size == -1:
        args.pa_state_size = 12 + (8 * args.footsteps) + (
            1 * args.po_bleeb) + (1 * args.see_surrounding) + (
                1 * args.extra_sensor_pa)

    print("ARGS IN GUI:", args)
    ################## for initialization ###########################
    global log_file

    # log_file = open('./Results_33_random/log.txt', 'w')

    animal_density = generate_map(args)
    # env = Env(args, animal_density, cell_length=None, canvas=None, gui=False)

    patrollers = [Patroller_CNN(args, 'pa_model' + str(i)) for i in range(5)]
    poachers = [Poacher(args, 'po_model' + str(i)) for i in range(5)]

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    args.po_location = None
    ### load the DQN models you have trained
    if args.load:
        # poachers[0] = Poacher_h(args, animal_density)
        # patrollers[0] = Patroller_h(args, animal_density)

        poachers[1].load(sess, args.po_load_path)
        patrollers[1].load(sess, args.pa_load_path)

        test_gui(poachers[1],
                 patrollers[1],
                 sess,
                 args,
                 pa_type='DQN',
                 po_type='DQN')

    ### test the random sweeping patroller and the heuristic poacher
    else:
        poacher = Poacher_h(args, animal_density)
        patroller = RandomSweepingPatroller(args)
        test_gui(poacher, patroller, sess, args, pa_type='RS', po_type='PARAM')
Example #5
def test_gui(poacher, patroller, sess, args, pa_type, po_type):
    """
    doc
    """
    #########################################################################################
    global e
    global t
    global episode_reward
    global pa_total_reward, po_total_reward, game_len, pa_episode_reward, po_episode_reward
    global pa_state, po_state, pa_action

    pa_total_reward, po_total_reward, game_len = [], [], []
    pa_action = 'still'

    master = Tk()
    cell_length = 80
    canvas_width = args.column_num * cell_length
    canvas_height = args.row_num * cell_length
    canvas = Canvas(master=master, width=canvas_width, height=canvas_height)
    canvas.grid()

    # animal_density = Mountainmap(args.row_num, args.column_num)
    # np.random.seed(args.ani_den_seed)
    # animal_density = np.random.uniform(low=0.2, high=1., size=[args.row_num, args.column_num])
    animal_density = generate_map(args)
    TestEnv = Env(args,
                  animal_density,
                  cell_length=cell_length,
                  canvas=canvas,
                  gui=True)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    e = 0
    t = 0

    def run_step():
        global e, t, pa_total_reward, po_total_reward, game_len, pa_episode_reward, po_episode_reward
        global pa_state, po_state, pa_action

        if t == 0:
            print('reset')
            poacher.reset_snare_num()
            pa_state, po_state = TestEnv.reset_game()
            pa_episode_reward, po_episode_reward = 0., 0.

        # the poacher takes an action; the branches below handle the different APIs of the DQN and heuristic agents
        if po_type == 'DQN':
            # the poacher can act only if it has not been caught and has not returned home
            if not TestEnv.catch_flag and not TestEnv.home_flag:
                po_state = np.array([po_state])
                snare_flag, po_action = poacher.infer_action(sess=sess,
                                                             states=po_state,
                                                             policy="greedy")
            else:
                snare_flag = 0
                po_action = 'still'
        elif po_type == 'PARAM':
            po_loc = TestEnv.po_loc
            if not TestEnv.catch_flag and not TestEnv.home_flag:
                snare_flag, po_action = poacher.infer_action(
                    loc=po_loc,
                    local_trace=TestEnv.get_local_pa_trace(po_loc),
                    local_snare=TestEnv.get_local_snare(po_loc),
                    initial_loc=TestEnv.po_initial_loc)
            else:
                snare_flag = 0
                po_action = 'still'

        # the patroller takes an action
        if pa_type == 'DQN':
            pa_state = np.array([pa_state])
            pa_action = patroller.infer_action(sess=sess,
                                               states=pa_state,
                                               policy="greedy")
        elif pa_type == 'PARAM':
            pa_loc = TestEnv.pa_loc
            pa_action = patroller.infer_action(
                pa_loc, TestEnv.get_local_po_trace(pa_loc), 1.5, -2.0, 8.0)
        elif pa_type == 'RS':
            pa_loc = TestEnv.pa_loc
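            # entries 4-7 of the poacher trace at the patroller's cell mark
            # footprints leading up/down/left/right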
            footprints = []
            actions = ['up', 'down', 'left', 'right']
            for i in range(4, 8):
                if TestEnv.po_trace[pa_loc[0], pa_loc[1]][i] == 1:
                    footprints.append(actions[i - 4])
            pa_action = patroller.infer_action(pa_loc, pa_action, footprints)

        # the TestEnv advances one step
        pa_state, pa_reward, po_state, po_reward, end_game = \
            TestEnv.step(pa_action, po_action, snare_flag, train = False)

        # print('poacher snare:', snare_flag)
        # # time.sleep(1)

        # accumulate the rewards
        pa_episode_reward += pa_reward
        po_episode_reward += po_reward

        # the game ends if the end_game condition is true or the maximum time step is reached
        if end_game or (t == args.max_time - 1):
            info = "episode\t%s\tlength\t%s\tpatroller_total_reward\t%s\tpoacher_total_reward\t%s" % \
                   (e, t + 1, pa_episode_reward, po_episode_reward)
            print(info)
            pa_total_reward.append(pa_episode_reward)
            game_len.append(t + 1)
            po_total_reward.append(po_episode_reward)
            t = 0
            e += 1
            if e == args.gui_test_num:
                master.destroy()
                return
        else:
            t += 1

        master.after(500, run_step)

    run_step()
    master.mainloop()

    #print(np.mean(ret_total_reward), np.mean(ret_average_reward), np.mean(ret_length))
    return np.mean(pa_total_reward), np.mean(po_total_reward), np.mean(
        game_len)
Example #6
def main():
    argparser = argparse.ArgumentParser(sys.argv[0])
    #########################################################################################
    # Environment
    argparser.add_argument('--row_num', type=int, default=7)
    argparser.add_argument('--column_num', type=int, default=7)
    argparser.add_argument('--ani_den_seed', type=int, default=66)
    argparser.add_argument('--zero_sum', type=int, default=1)
    argparser.add_argument('--reward_shaping', type=bool, default=False)
    argparser.add_argument('--po_location', type=int, default=None)
    argparser.add_argument('--map_type', type=str, default='random')

    # Patroller
    argparser.add_argument('--pa_state_size',
                           type=int,
                           default=20,
                           help="patroller state dimension")
    argparser.add_argument('--pa_num_actions',
                           type=int,
                           default=5,
                           help="still, up, down, left, right")

    # Poacher
    argparser.add_argument('--po_state_size',
                           type=int,
                           default=22,
                           help="poacher state dimension")
    argparser.add_argument('--po_num_actions',
                           type=int,
                           default=10,
                           help="still, up, down, left, right x put, not put")
    argparser.add_argument('--snare_num', type=int, default=6)
    argparser.add_argument('--po_act_den_w', type=float, default=3.)
    argparser.add_argument('--po_act_enter_w', type=float, default=0.3)
    argparser.add_argument('--po_act_leave_w', type=float, default=-1.0)
    argparser.add_argument('--po_act_temp',
                           type=float,
                           default=5.0,
                           help="softmax temperature")
    argparser.add_argument('--po_home_dir_w', type=float, default=3.0)

    # Training
    argparser.add_argument('--save_dir',
                           type=str,
                           default='./pg_models_pa/',
                           help='models_pa/')
    argparser.add_argument('--pa_load_dir',
                           type=str,
                           default='None',
                           help='models_pa/model.ckpt')
    argparser.add_argument('--log_file', type=str, default='log_train.txt')
    argparser.add_argument('--episode_num', type=int, default=300000)
    argparser.add_argument('--max_time',
                           type=int,
                           default=75,
                           help='maximum time step per episode')
    argparser.add_argument('--train_every_episode', type=int, default=30)
    argparser.add_argument('--target_update_every_episode',
                           type=int,
                           default=100,
                           help="for state value function")
    argparser.add_argument('--reward_gamma', type=float, default=0.95)
    argparser.add_argument('--initial_lr', type=float, default=1e-5)
    argparser.add_argument('--save_every_episode', type=int, default=10000)
    argparser.add_argument('--test_every_episode', type=int, default=20000)
    argparser.add_argument('--test_episode_num', type=int, default=5000)
    #########################################################################################
    args = argparser.parse_args()
    if args.pa_load_dir == 'None':
        args.pa_load_dir = None
    if args.save_dir == 'None':
        args.save_dir = None

    if args.row_num == 3:
        args.column_num = 3
        args.max_time = 4
        args.snare_num = 3

    if args.row_num == 5:
        args.column_num = 5
        args.max_time = 25

    if args.row_num == 7:
        args.column_num = 7
        args.max_time = 75

    # get animal density
    animal_density = generate_map(args)
    env = Env(args, animal_density, cell_length=None, canvas=None, gui=False)
    poacher = Poacher(args, animal_density)
    patroller_value = PatrollerValue(args, "pa_value_model")
    # target_patroller_value = PatrollerValue(args, 'pa_value_target')
    patroller_policy = PatrollerPolicy(args, "pa_policy")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Load model if necessary
    if args.pa_load_dir:
        patroller_policy.load(sess=sess, filename=args.pa_load_dir)
        print('Load model for patroller from ' + args.pa_load_dir)

    if args.save_dir and (not os.path.exists(args.save_dir)):
        os.mkdir(args.save_dir)

    # Running Initialization
    log = open(args.save_dir + args.log_file, 'w')
    test_log = open(args.save_dir + 'test_log.txt', 'w')

    learning_rate = args.initial_lr
    action_id = {'still': 0, 'up': 1, 'down': 2, 'left': 3, 'right': 4}
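    # Maps the action names returned by the policy network to the integer indices
    # expected by the policy-gradient training op.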

    copy_ops = []
    # for target_w, model_w in zip(target_patroller_value.variables, patroller_value.variables):
    #     op = target_w.assign(model_w)
    #     copy_ops.append(op)
    # sess.run(copy_ops)
    # print("Update target value network parameter!")

    train_pre_state = []
    train_action = []
    train_reward = []
    train_post_state = []

    for e in range(args.episode_num):
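        # Decay the learning rate every 500,000 episodes, with a floor of 1e-7.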
        if e > 0 and e % 500000 == 0:
            learning_rate = max(0.0000001, learning_rate / 2.)

        # reset the environment
        poacher.reset_snare_num()
        pa_state, _ = env.reset_game()
        episode_reward = 0.

        for t in range(args.max_time):
            po_loc = env.po_loc
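            # Once the poacher has been caught (env.catch_flag), it stays still
            # and places no more snares.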
            if not env.catch_flag:
                snare_flag, po_action = poacher.infer_action(
                    loc=po_loc,
                    local_trace=env.get_local_pa_trace(po_loc),
                    local_snare=env.get_local_snare(po_loc),
                    initial_loc=env.po_initial_loc)
            else:
                snare_flag = 0
                po_action = 'still'

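            # Record the pre-decision state and the sampled patroller action so the
            # policy-gradient update can replay this transition later.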
            train_pre_state.append(pa_state)
            pa_state = np.array(
                [pa_state])  # Make it 2-D, i.e., [batch_size(1), state_size]
            pa_action = patroller_policy.infer_action(sess=sess,
                                                      states=pa_state)
            train_action.append(action_id[pa_action])

            # the game moves on a step.
            pa_state, pa_reward, po_state, _, end_game = \
              env.step(pa_action, po_action, snare_flag)

            train_reward.append(pa_reward)

            episode_reward += pa_reward

            # Get new state
            train_post_state.append(pa_state)

            if end_game:
                info = "episode\t%s\tlength\t%s\ttotal_reward\t%s\taverage_reward\t%s" % \
                       (e, t + 1, episode_reward, 1. * episode_reward / (t + 1))
                print(info)
                log.write(info + '\n')
                log.flush()
                break

        # Train
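        # Every train_every_episode episodes, run one actor-critic update over the
        # collected transitions: fit V(s) towards r + gamma * V(s'), then use the
        # TD error r + gamma * V(s') - V(s) as the advantage for the policy gradient.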
        if e > 0 and e % args.train_every_episode == 0:
            # Fit value function
            post_state_value = patroller_value.get_state_value(
                sess=sess, states=train_post_state)
            state_value_target = np.array(
                train_reward) + args.reward_gamma * np.array(post_state_value)
            feed = {
                patroller_value.input_state: train_pre_state,
                patroller_value.state_values_target: state_value_target,
                patroller_value.learning_rate: learning_rate
            }
            sess.run(patroller_value.train_op, feed_dict=feed)

            # Get advantage value
            pre_state_value = patroller_value.get_state_value(
                sess=sess, states=train_pre_state)
            advantage = np.array(train_reward) + args.reward_gamma * np.array(post_state_value) - \
                        np.array(pre_state_value)

            # Train policy
            feed = {
                patroller_policy.input_state: train_pre_state,
                patroller_policy.actions: train_action,
                patroller_policy.advantage: advantage,
                patroller_policy.learning_rate: learning_rate
            }
            sess.run(patroller_policy.train_op, feed_dict=feed)

            # Clear the training buffer
            train_pre_state = []
            train_action = []
            train_reward = []
            train_post_state = []

        # Test
        if e > 0 and e % args.test_every_episode == 0:
            test_total_reward, test_average_reward, test_length = test(
                poacher, patroller_policy, env, sess, args)
            info = [test_total_reward, test_average_reward, test_length]
            info = [str(x) for x in info]
            info = '\t'.join(info) + '\n'
            print(info)
            test_log.write(info)
            test_log.flush()

        # Update target
        # if e > 0 and e % args.target_update_every_episode == 0:
        #     sess.run(copy_ops)
        #     print("Update target value network parameter!")

        # Save model
        if e > 0 and e % args.save_every_episode == 0:
            save_name = os.path.join(args.save_dir, str(e), "model.ckpt")
            patroller_policy.save(sess=sess, filename=save_name)
            print('Save model to ' + save_name)

    test_log.close()
    log.close()
Example #7
0
def main():
    ''' Main function: initializes various variables and contains the main program loop.
        Should only be called by running this file directly.
        Takes no arguments and returns nothing.
    '''
    # initialize pygame
    pygame.init()
    
    # Make map
    maps.generate_map()
    # Initiate player
    g.special_entity_list["player"] = players.Player(g.player_start_x, g.player_start_y)
    # Create a window just large enough to fit all the tiles in the map file.
    pygame.display.set_icon(g.images["icon"].get())
    pygame.display.set_caption("TileGame by ZeeQyu", "TileGame")
    g.screen = pygame.display.set_mode((g.width * c.TILE_SIZE,
                                              g.height * c.TILE_SIZE))
    
    # A flag for skipping a single cycle after, e.g., accessing a menu, so that
    # the entities won't fly across the screen
    skip_cycle = False
    
    # Get time once initially and make time variables
    time_last_tick = time_start = time_prev = time.clock()
    time_start = time_cycles = time_updates = time_last_sleep = 0
    menu = interface.BuildMenu()
    
    # Main loop
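    # Each iteration: handle input events, tick entities at fixed intervals,
    # update entity positions, and redraw the screen only when something moved.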
    while True:
        # Make the screen update every frame
        if c.FORCE_UPDATE:
            g.force_update = True
        # Event checker. Allows closing of the program and passes keypresses to the player instance
        for event in pygame.event.get():
            # Quit code
            if event.type == pgl.QUIT:
                sys.exit()
            if event.type == pgl.KEYDOWN or event.type == pgl.KEYUP:
                # Create beetle with (default) space
                if event.type == pgl.KEYDOWN and event.key == g.key_dict["spawn_beetle"][0]:
                    g.entity_list.append(units.Beetle(g.special_entity_list["player"].x,
                                                            g.special_entity_list["player"].y))
                # Duplicate all beetles with (default) D
                elif event.type == pgl.KEYDOWN and event.key == g.key_dict["duplicate_beetles"][0]:
                    # Store the new beetles in a temporary list so g.entity_list isn't
                    # extended while it is being iterated (which would loop forever)
                    temp_entity_list = []
                    for entity in g.entity_list:
                        if type(entity) == units.Beetle:
                            temp_entity_list.append(units.Beetle(entity.x, entity.y))
                    g.entity_list.extend(temp_entity_list)
                    temp_entity_list = []
                # Remove all beetles
                elif event.type == pgl.KEYDOWN and event.key == g.key_dict["remove_beetles"][0]:
                    # Loop backwards through the g.entity_list
                    for i in range(len(g.entity_list)-1, -1, -1):
                        if type(g.entity_list[i]) == units.Beetle:
                            del g.entity_list[i]
                    g.force_update = True
                # Key configuration
                elif event.type == pgl.KEYDOWN and event.key == c.CONFIG_KEYS_KEY:
                    skip_cycle = g.force_update = True
                    interface.key_reconfig()
                # Otherwise, check for if the player should move
                g.special_entity_list["player"].event_check(event)
                
        # Tick: make sure certain things happen at a fixed rate rather than every frame
        time_now = time.clock()
        time_diff = time_now - time_prev
        # If more than two seconds have passed, movement might jump abruptly, so a cycle should be skipped
        if time_prev + 2 < time_now:
            skip_cycle = True
        time_prev = time_now
        # Skip the rest of this cycle if a menu was accessed until now
        if skip_cycle:
            skip_cycle = False
            continue
        # FPS meter (printed to the console): counts how many times this loop runs per second and prints it once a second.
        time_cycles += 1
        if time_start + 1 < time_now:
            if time_updates == 1 and time_cycles == 1:
                time_updates = 1.0 / (time_diff)
            print time_start, "seconds from start,",  time_cycles, "cycles,", time_updates, "fps"
            time_cycles = 0
            time_updates = 0
            time_start = time_now
        # What happens every tick?
        if time_last_tick + c.TICK_FREQ < time_now:
            time_last_tick = time_last_tick + c.TICK_FREQ
            # Tick all the entities (let them do whatever they do every tick)
            for i in range(len(g.entity_list)-1, -1, -1):
                entity = g.entity_list[i]
                entity.tick()
            for entity in g.special_entity_list.values():
                entity.tick()
            for tile in g.tick_tiles:
                g.map[tile[0]][tile[1]].tick()
        # Make sure the loop doesn't go too quickly and bog the processor down
        if time_last_sleep < c.SLEEP_TIME:
            time.sleep(c.SLEEP_TIME - time_last_sleep)

        # Update map buffer if needed
        if g.update_map:
            g.update_map = False
            g.force_update = True
            g.map_screen_buffer = maps.update_map()
        # update all entities
        entity_has_moved = False
        if g.entity_list:
            for i in range(len(g.entity_list)-1, -1, -1):
                entity = g.entity_list[i]
                entity.update(time_diff)
                # Check if any of them have moved
                if entity.has_moved():
                    entity_has_moved = True
        if g.special_entity_list:
            for entity in g.special_entity_list.values():
                # Update all entities and check whether any of them is a package that just finished moving.
                # If so, skip the has_moved check.
                if entity.update(time_diff) == "deleted":
                    continue
                if entity.has_moved():
                    entity_has_moved = True
        
        # If any entity moved, redraw the screen
        if entity_has_moved or g.force_update:
            g.force_update = False
            time_updates += 1
            g.screen.fill(c.BACKGROUND_COLOR)
            # Draw the map buffer on the screen
            g.screen.blit(g.map_screen_buffer, (0, 0))
            # Draw the entities
            for i in range(len(g.entity_list)-1, -1, -1):
                entity = g.entity_list[i]
                entity.paint()
            for i in range(len(g.special_entity_list.values())-1, -1, -1):
                entity = g.special_entity_list.values()[i]
                entity.paint()
            #menu.paint()
            # Update the display
            pygame.display.flip()