Example #1
0
def args_gnubg(args):
    """Load one saved agent and evaluate it against gnubg.

    Reads ``model_agent0``, ``type``, ``hidden_units_agent0``, ``episodes``,
    ``host``, ``port`` and ``difficulty`` from ``args``. Nothing is loaded or
    evaluated when ``path_exists`` rejects the checkpoint path.

    :param args: parsed command-line namespace
    :return: None
    """
    checkpoint = args.model_agent0
    net_kind = args.type
    hidden_units = args.hidden_units_agent0
    episodes = args.episodes
    gnubg_host = args.host
    gnubg_port = args.port
    level = args.difficulty

    if not path_exists(checkpoint):
        return

    # 'nn' selects the TDGammon network; every other type uses the CNN.
    if net_kind == 'nn':
        network = TDGammon(hidden_units=hidden_units,
                           lr=0.1,
                           lamda=None,
                           init_weights=False)
    else:
        network = TDGammonCNN(lr=0.0001)

    network.load(checkpoint_path=checkpoint,
                 optimizer=None,
                 eligibility_traces=False)

    interface = GnubgInterface(host=gnubg_host, port=gnubg_port)
    environment = GnubgEnv(interface,
                           difficulty=level,
                           model_type=net_kind)
    player = TDAgentGNU(WHITE, net=network, gnubg_interface=interface)
    evaluate_vs_gnubg(agent=player, env=environment, n_episodes=episodes)
Example #2
0
def _checkpoint_iteration(filename):
    """Return the iteration number that ends a ``.tar`` file name, or None.

    ``"model_1500.tar"`` gives ``1500``; names that are not ``.tar`` files or
    that carry no trailing digits give ``None``.
    """
    suffix = ".tar"
    if not filename.endswith(suffix):
        return None
    stem = filename[:-len(suffix)]
    digits = stem[len(stem.rstrip("0123456789")):]
    return int(digits) if digits else None


def _select_checkpoint(filenames, iterations):
    """Pick the checkpoint saved at ``iterations``, or the latest one if unset.

    Returns the chosen file name, or ``None`` when no file qualifies.
    """
    numbered = []
    # Sorted so the choice never depends on the directory listing order.
    for name in sorted(filenames):
        count = _checkpoint_iteration(name)
        if count is not None:
            numbered.append((count, name))

    if iterations:
        wanted = int(iterations)
        # Compare whole numbers: a suffix test would let 100 match "_1100.tar".
        for count, name in numbered:
            if count == wanted:
                return name
        return None

    if not numbered:
        return None
    return max(numbered)[1]


def args_gnubg(args):
    """Evaluate a checkpoint from ``saved_models/<model_agent0>`` against gnubg.

    The experiment folder is resolved relative to the current working
    directory. When ``args.iterations`` is set, the checkpoint whose trailing
    number equals it is used; otherwise the checkpoint with the largest
    trailing number is used.

    :param args: parsed command-line namespace; reads ``model_agent0``,
        ``type``, ``hidden_units_agent0``, ``episodes``, ``host``, ``port``,
        ``difficulty`` and ``iterations``.
    :raises FileNotFoundError: if the folder does not exist or holds no
        matching ``.tar`` checkpoint.
    :return: None
    """
    model_agent0 = args.model_agent0
    model_type = args.type
    hidden_units_agent0 = args.hidden_units_agent0
    n_episodes = args.episodes
    host = args.host
    port = args.port
    difficulty = args.difficulty
    iterations = args.iterations

    folder = os.getcwd() + "/saved_models/" + model_agent0

    final_file = _select_checkpoint(os.listdir(folder), iterations)
    if final_file is None:
        # Previously this surfaced as an UnboundLocalError on ``final_file``.
        raise FileNotFoundError(
            "No matching '.tar' checkpoint found in {}".format(folder))

    checkpoint_path = os.path.join(folder, final_file)

    if path_exists(checkpoint_path):
        if model_type == 'nn':
            net0 = TDGammon(hidden_units=hidden_units_agent0,
                            lr=0.1,
                            lamda=None,
                            init_weights=False)
        else:
            net0 = TDGammonCNN(lr=0.0001)

        net0.load(checkpoint_path=checkpoint_path,
                  optimizer=None,
                  eligibility_traces=False)

        gnubg_interface = GnubgInterface(host=host, port=port)
        gnubg_env = GnubgEnv(gnubg_interface,
                             difficulty=difficulty,
                             model_type=model_type)
        evaluate_vs_gnubg(agent=TDAgentGNU(WHITE,
                                           net=net0,
                                           gnubg_interface=gnubg_interface),
                          env=gnubg_env,
                          n_episodes=n_episodes,
                          difficulty=difficulty,
                          model=model_agent0)
Example #3
0
def args_plot(args, parser):
    """Evaluate every saved checkpoint under ``args.save_path`` and log the wins.

    Evaluating the agent during training slows the training down, so the wins
    are plotted separately by reloading the models saved along the way. For
    example, if training ran for 100 games with a checkpoint every 10 games,
    the 10 checkpoints are loaded one by one and each plays ``args.episodes``
    games against every requested opponent.

    Win counts are written as scalars through a ``SummaryWriter`` on
    ``args.dst``. For the ``gnubg`` opponent they are also appended to
    ``results.txt`` in the directory that holds the checkpoint.

    :param args: parsed command-line namespace; reads ``save_path``,
        ``hidden_units``, ``episodes``, ``opponent`` and ``difficulty``
        (comma-separated strings), ``host``, ``port``, ``type`` and ``dst``.
    :param parser: argument parser whose ``error`` method is used to report
        invalid option values.
    :return: None
    """
    src = args.save_path
    hidden_units = args.hidden_units
    n_episodes = args.episodes
    opponents = args.opponent.split(',')
    host = args.host
    port = args.port
    difficulties = args.difficulty.split(',')
    model_type = args.type

    if not path_exists(src):
        return

    allowed_difficulties = ('beginner', 'intermediate', 'advanced',
                            'world_class')
    for d in difficulties:
        if d not in allowed_difficulties:
            parser.error(
                "--difficulty should be (one or more of) 'beginner','intermediate', 'advanced' ,'world_class'"
            )

    dst = args.dst

    if 'gnubg' in opponents and (not host or not port):
        parser.error(
            "--host and --port are required when 'gnubg' is specified in --opponent"
        )

    for root, _dirs, files in os.walk(src):
        # Only real checkpoint files; a substring test would also accept
        # names such as "x.tar.gz".
        checkpoints = sorted(f for f in files if f.endswith(".tar"))
        if not checkpoints:
            continue

        writer = SummaryWriter(dst)
        try:
            # global_step counts the checkpoints of this directory in sorted
            # file-name order.
            for global_step, file in enumerate(checkpoints):
                checkpoint_path = os.path.join(root, file)
                print("\nLoad {}".format(checkpoint_path))

                if model_type == 'nn':
                    net = TDGammon(hidden_units=hidden_units,
                                   lr=0.1,
                                   lamda=None,
                                   init_weights=False)
                    env = gym.make('gym_backgammon:backgammon-v0')
                else:
                    net = TDGammonCNN(lr=0.0001)
                    env = gym.make('gym_backgammon:backgammon-pixel-v0')

                net.load(checkpoint_path=checkpoint_path,
                         optimizer=None,
                         eligibility_traces=False)

                if 'gnubg' in opponents:
                    tag_scalar_dict = {}

                    gnubg_interface = GnubgInterface(host=host, port=port)

                    for difficulty in difficulties:
                        gnubg_env = GnubgEnv(gnubg_interface,
                                             difficulty=difficulty,
                                             model_type=model_type)
                        wins = evaluate_vs_gnubg(
                            agent=TDAgentGNU(WHITE,
                                             net=net,
                                             gnubg_interface=gnubg_interface),
                            env=gnubg_env,
                            n_episodes=n_episodes)
                        tag_scalar_dict[difficulty] = wins[WHITE]

                    writer.add_scalars('wins_vs_gnubg/', tag_scalar_dict,
                                       global_step)

                    with open(os.path.join(root, 'results.txt'), 'a') as f:
                        print("{};".format(file) + str(tag_scalar_dict),
                              file=f)

                if 'random' in opponents:
                    tag_scalar_dict = {}
                    agents = {
                        WHITE: TDAgent(WHITE, net=net),
                        BLACK: RandomAgent(BLACK)
                    }
                    wins = evaluate_agents(agents, env, n_episodes)

                    tag_scalar_dict['random'] = wins[WHITE]

                    writer.add_scalars('wins_vs_random/', tag_scalar_dict,
                                       global_step)
        finally:
            # Close once per directory, after all its checkpoints are logged
            # and also when an evaluation fails part-way.
            writer.close()