예제 #1
0
def _load_or_init_net(net, net_to_play):
    """Load ``net``'s weights from ./model_data/, or create the checkpoint.

    If ``./model_data/<net_to_play>`` exists, its ``'state_dict'`` entry is
    loaded into ``net`` in place. Otherwise the current weights of ``net``
    are written to that path so the file exists for later runs.

    Args:
        net: Model exposing ``state_dict()`` / ``load_state_dict()``.
        net_to_play: Checkpoint file name (not a path) inside ./model_data/.
    """
    model_dir = "./model_data/"
    current_net_filename = os.path.join(model_dir, net_to_play)
    if os.path.isfile(current_net_filename):
        # On a CPU-only host, remap storages to CPU so that a checkpoint
        # written from a CUDA model can still be deserialized.
        map_location = None if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(current_net_filename, map_location=map_location)
        net.load_state_dict(checkpoint['state_dict'])
        logger.info("Loaded %s model." % current_net_filename)
    else:
        # torch.save does not create missing parent directories.
        os.makedirs(model_dir, exist_ok=True)
        torch.save({'state_dict': net.state_dict()}, current_net_filename)
        logger.info("Initialized model.")


def run_MCTS(args, start_idx=0, iteration=0):
    """Run MCTS self-play with the network checkpoint of one iteration.

    A fresh ``ConnectNet`` is created, moved to CUDA when available, and
    populated from ``./model_data/<neural_net_name>_iter<iteration>.pth.tar``
    (the file is created from the fresh weights if it does not exist).
    ``MCTS_self_play`` is then called either directly or in one worker
    process per requested process.

    Args:
        args: Configuration object. Reads ``neural_net_name``,
            ``MCTS_num_processes`` and ``num_games_per_MCTS_process``; the
            object is also forwarded unchanged to ``MCTS_self_play``.
        start_idx: Forwarded unchanged to ``MCTS_self_play``.
        iteration: Iteration number used in the checkpoint file name and
            forwarded to ``MCTS_self_play``.

    Returns:
        None.
    """
    requested = args.MCTS_num_processes
    if requested < 1:
        # Neither execution mode applies; say so instead of silently
        # doing nothing.
        logger.warning("MCTS_num_processes is %s; no self-play was run." % requested)
        return

    net_to_play = "%s_iter%d.pth.tar" % (args.neural_net_name, iteration)
    net = ConnectNet()
    if torch.cuda.is_available():
        net.cuda()

    if requested == 1:
        logger.info("Preparing model for MCTS...")
        net.eval()
        _load_or_init_net(net, net_to_play)

        with torch.no_grad():
            MCTS_self_play(net, args.num_games_per_MCTS_process, start_idx, 0, args, iteration)
        logger.info("Finished MCTS!")
        return

    logger.info("Preparing model for multi-process MCTS...")
    mp.set_start_method("spawn", force=True)
    # Share the parameters before loading: load_state_dict copies into the
    # existing tensors, so the workers see the loaded weights.
    net.share_memory()
    net.eval()
    _load_or_init_net(net, net_to_play)

    # Never spawn more workers than there are CPUs.
    num_processes = min(requested, mp.cpu_count())
    if num_processes < requested:
        logger.info("Required number of processes exceed number of CPUs! Setting MCTS_num_processes to %d" % num_processes)

    logger.info("Spawning %d processes..." % num_processes)
    # NOTE(review): grad mode is per-thread state in PyTorch, so this context
    # probably does not reach the spawned workers -- confirm that
    # MCTS_self_play disables gradients itself if that matters.
    with torch.no_grad():
        processes = []
        for i in range(num_processes):
            p = mp.Process(target=MCTS_self_play, args=(net, args.num_games_per_MCTS_process,
                                                        start_idx, i, args, iteration))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    logger.info("Finished multi-process MCTS!")
예제 #2
0
def _load_eval_checkpoint(net, checkpoint_path):
    """Load the ``'state_dict'`` entry of ``checkpoint_path`` into ``net``."""
    # On a CPU-only host, remap storages to CPU so that a checkpoint written
    # from a CUDA model can still be deserialized.
    map_location = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    net.load_state_dict(checkpoint['state_dict'])


def evaluate_nets(args, iteration_1, iteration_2):
    """Pit the ``iteration_2`` net against the ``iteration_1`` net.

    Both checkpoints are read from
    ``./model_data/<neural_net_name>_iter<N>.pth.tar``. Games are played
    through ``arena`` (directly, or via ``fork_process`` in worker
    processes), and the outcome is read back from the ``wins_cpu_<i>``
    pickles with ``load_pickle``.

    Args:
        args: Configuration object. Reads ``neural_net_name``,
            ``MCTS_num_processes`` and ``num_evaluator_games``.
        iteration_1: Iteration number of the previous ("best") net.
        iteration_2: Iteration number of the current net.

    Returns:
        ``iteration_2`` when the recorded ``'best_win_ratio'`` (averaged over
        workers in the multi-process case) is at least 0.55, otherwise
        ``iteration_1``. Only ``MCTS_num_processes > 1`` and ``== 1`` are
        handled here.
    """
    logger.info("Loading nets...")
    current_net = "%s_iter%d.pth.tar" % (args.neural_net_name, iteration_2)
    best_net = "%s_iter%d.pth.tar" % (args.neural_net_name, iteration_1)
    current_net_filename = os.path.join("./model_data/", current_net)
    best_net_filename = os.path.join("./model_data/", best_net)

    logger.info("Current net: %s" % current_net)
    logger.info("Previous (Best) net: %s" % best_net)

    current_cnet = ConnectNet()
    best_cnet = ConnectNet()
    cuda = torch.cuda.is_available()
    if cuda:
        current_cnet.cuda()
        best_cnet.cuda()

    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("./evaluator_data/", exist_ok=True)

    if args.MCTS_num_processes > 1:
        mp.set_start_method("spawn", force=True)

        # Share the parameters before loading: load_state_dict copies into
        # the existing tensors, so the workers see the loaded weights.
        for cnet in (current_cnet, best_cnet):
            cnet.share_memory()
            cnet.eval()

        _load_eval_checkpoint(current_cnet, current_net_filename)
        _load_eval_checkpoint(best_cnet, best_net_filename)

        # Never spawn more workers than there are CPUs.
        num_processes = min(args.MCTS_num_processes, mp.cpu_count())
        if num_processes < args.MCTS_num_processes:
            logger.info(
                "Required number of processes exceed number of CPUs! Setting MCTS_num_processes to %d"
                % num_processes)
        logger.info("Spawning %d processes..." % num_processes)

        processes = []
        with torch.no_grad():
            for i in range(num_processes):
                p = mp.Process(target=fork_process,
                               args=(arena(current_cnet, best_cnet),
                                     args.num_evaluator_games, i))
                p.start()
                processes.append(p)
            for p in processes:
                p.join()

        # NOTE(review): presumably each worker writes "wins_cpu_<i>"; if a
        # worker dies early a stale file from an earlier run could be read
        # here -- confirm against fork_process / arena.evaluate.
        wins_ratio = sum(
            load_pickle("wins_cpu_%i" % (i))['best_win_ratio']
            for i in range(num_processes)
        ) / num_processes
        return iteration_2 if wins_ratio >= 0.55 else iteration_1

    elif args.MCTS_num_processes == 1:
        current_cnet.eval()
        best_cnet.eval()
        _load_eval_checkpoint(current_cnet, current_net_filename)
        _load_eval_checkpoint(best_cnet, best_net_filename)
        arena1 = arena(current_cnet=current_cnet, best_cnet=best_cnet)
        arena1.evaluate(num_games=args.num_evaluator_games, cpu=0)

        stats = load_pickle("wins_cpu_%i" % (0))
        # Subscript access, matching the multi-process branch; the original
        # attribute access (stats.best_win_ratio) was inconsistent with it.
        return iteration_2 if stats['best_win_ratio'] >= 0.55 else iteration_1
    best_net = "current_net_trained1_iter2.pth.tar"
    current_net_filename = os.path.join("./model_data/",\
                                    current_net)
    best_net_filename = os.path.join("./model_data/",\
                                    best_net)
    current_cnet = ConnectNet()
    best_cnet = ConnectNet()
    cuda = torch.cuda.is_available()
    if cuda:
        current_cnet.cuda()
        best_cnet.cuda()

    if multiprocessing == 1:
        mp.set_start_method("spawn", force=True)

        current_cnet.share_memory()
        best_cnet.share_memory()
        current_cnet.eval()
        best_cnet.eval()

        checkpoint = torch.load(current_net_filename)
        current_cnet.load_state_dict(checkpoint['state_dict'])
        checkpoint = torch.load(best_net_filename)
        best_cnet.load_state_dict(checkpoint['state_dict'])

        processes = []
        for i in range(6):
            p = mp.Process(target=fork_process,
                           args=(arena(current_cnet, best_cnet), 10, i))
            p.start()
            processes.append(p)
예제 #4
0
        save_as_pickle(
            "dataset_cpu%i_%i_%s" %
            (cpu, idxx, datetime.datetime.today().strftime("%Y-%m-%d")),
            dataset_p)


if __name__ == "__main__":
    multiprocessing = 1
    if multiprocessing == 1:
        net_to_play = "c4_current_net_trained2_iter5.pth.tar"
        mp.set_start_method("spawn", force=True)
        net = ConnectNet()
        cuda = torch.cuda.is_available()
        if cuda:
            net.cuda()
        net.share_memory()
        net.eval()
        print("hi")
        # torch.save({'state_dict': net.state_dict()}, os.path.join("./model_data/",\
        #                                "c4_current_net.pth.tar"))

        current_net_filename = os.path.join("./model_data/",\
                                        net_to_play)
        checkpoint = torch.load(current_net_filename)
        net.load_state_dict(checkpoint['state_dict'])
        processes = []
        for i in range(6):
            p = mp.Process(target=MCTS_self_play, args=(net, 25, i))
            p.start()
            processes.append(p)
        for p in processes: