示例#1
0
def hist_compare(game, name='dense', num_games=100, num_nets=10):
    """Pit each saved checkpoint against its predecessor and plot the results.

    For every ``k`` in ``1..num_nets`` the network loaded from
    ``<name>/<name><k>.pth`` (player 1) plays ``num_games`` games against the
    network loaded from ``<name>/<name><k-1>.pth`` (player 2).  For ``k == 1``
    the opponent is a freshly constructed, untrained network ("net 0").
    The per-round win and loss counts are plotted and written to
    ``HistComp<name>.png`` in the current working directory.

    NOTE(review): the architecture is hard-coded to ``Nets.Conv`` while
    ``name`` defaults to ``'dense'``; loading checkpoints from a different
    architecture presumably fails in ``load_state_dict`` -- confirm and pass
    ``name='conv'`` (as the visible caller does) until this is parameterised.

    Args:
        game: Game object forwarded unchanged to ``pit_against_network``.
        name: Checkpoint directory and file-name prefix; also used in the
            output image name.
        num_games: Number of games per pairing, forwarded to
            ``pit_against_network``.
        num_nets: Index of the last checkpoint to evaluate (default 10,
            matching the previously hard-coded range).

    Returns:
        None.  Results are printed and saved as a PNG.
    """

    def _load_net(index):
        # Build an eval-mode network; index 0 keeps its initial weights.
        net = Nets.Conv()
        if index > 0:
            net.load_state_dict(
                torch.load('{0}/{0}{1}.pth'.format(name, index)))
        net.eval()
        return net

    wins_list = []
    losses_list = []
    for net_no in range(1, num_nets + 1):
        # Opponent first, then the newer net: same construction order as
        # before, so any RNG consumed by weight initialisation is unchanged.
        player2 = _load_net(net_no - 1)
        player1 = _load_net(net_no)

        print('Pitting', net_no - 1, 'and', net_no)
        # Counts are presumably from player1's point of view -- TODO confirm
        # against pit_against_network.
        wins, draws, losses = pit_against_network(
            player1, player2, game, num_games)
        print('wins:', wins)
        print('draws:', draws)
        print('losses:', losses)
        # Draws are only reported on stdout; they are not part of the plot.
        wins_list.append(wins)
        losses_list.append(losses)

    # Draw on a dedicated figure and close it afterwards.  Plotting on the
    # implicit global figure would overlay curves from earlier calls (or any
    # unrelated plotting) in the saved image and leak the figure.
    # Assumes `plt` is matplotlib.pyplot, as the original plot/savefig calls
    # suggest.
    fig, ax = plt.subplots()
    ax.plot(wins_list, color='green')
    ax.plot(losses_list, color='red', alpha=0.3)
    fig.savefig('HistComp' + name + '.png')
    plt.close(fig)
示例#2
0
import numpy as np
import torch

import Nets
from eval import pit_against_network, pit_human, hist_compare, pit_onelookahead
from Connect4Game import Connect4Game

if __name__ == '__main__':

    # initialize networks
    net = Nets.Conv()
    net.load_state_dict(torch.load('convbkup/conv10.pth'))
    net.eval()
    aux_net = Nets.Conv()
    #aux_net.load_state_dict(torch.load('convbkup/conv1.pth'))
    aux_net.eval()

    # initialize game
    game = Connect4Game()
    """ wins,draws,losses = pit_against_network(aux_net,net,game,100)

    print('Wins:',wins)
    print('Losses:',losses)
    print('Draws:',draws) """

    #pit_human(aux_net,game)
    hist_compare(game, name='conv', num_games=100)
    """ wins,draws,losses = pit_onelookahead(game,aux_net,1)

    print('Wins:',wins)
    print('Losses:',losses)