Example 1

import math
import random

# Player, NeuralNet and Pair are assumed to be provided elsewhere in the project.

class NNPlayer(Player):
    """ A manacala player uses a neural network to store its approximation function """
    
    LEGAL_STRATEGY = [
        'greedy',
        'random', 
        'weighted', 
        'exponential'
    ]
    
    def __init__(self, id, row=6, numStones=4):
        self.setID(id)
        
        self.learn = True
        self.alpha = 0.5
        self.discount = 0.9
        
        self.rowSize = row
        self.stones = numStones
        self.movelist = [[], []]  # two separate lists to allow for playing against self
        
        # state (2 mancalas + 2*rowSize pits) plus one slot for the candidate action
        self.inputSize = 2 + 2 * self.rowSize + 1
        self.Q = NeuralNet(self.inputSize, 2 * self.inputSize)  # hidden layer twice the size of the input layer
        # if exploit, choose expected optimal move
        # otherwise, explore (randomize choice)
        self.strategy = "greedy"
        
        self.recentGames = []   # holds the transcripts of the most recent games
        self.numIterations = 1
        self.numRecent = 1      # number of games to track as recent
        
    def setID(self, id):
        """ set player identity """
        if id > 1 or id < 0:
            return False
        self.id = id
        return True
    
    def setLearning(self, toLearn):
        self.learn = toLearn
        
    def setDiscountFactor(self, discount):
        """ set discount factor """
        if discount > 1 or discount < 0:
            return False
        self.discount = discount
        return True
    
    def setStrategy(self, strategy):
        """ if given strategy is supported return true """
        if strategy in NNPlayer.LEGAL_STRATEGY:
            self.strategy = strategy
            return True
        return False
    
    def getMove(self, board):
        """ chooses next move """
        state = self._getState(board)
        qVals = self._getQvals(board)
        myside = board.mySide(self.id)
        validMoves = [index for index, val in enumerate(myside) if val > 0]
        
        # if there is no action available, signal it with -1
        if len(validMoves) == 0:
            return -1
        # condense to only non-empty pits (validMoves holds 0-based pit indices)
        validQVals = [qVals[val] for val in validMoves]
            
        # choose action based on strategy
        if self.strategy == 'greedy':
            validMove = self._getBestIndex(validQVals)
        elif self.strategy == 'random':
            validMove = self._getRandIndex(validQVals)
        elif self.strategy == 'weighted':
            validMove = self._getWeightedIndex(validQVals)
        elif self.strategy == 'exponential':
            validMove = self._getExponentialIndex(validQVals)
        else:   # fall back to greedy
            validMove = self._getBestIndex(validQVals)
        
        move = validMoves[validMove]
        self.movelist[self.id].append(Pair(state, move))
        return move
        
    def _getRandIndex(self, validQvals):
        """ chooses a move randomly with uniform distribution """
        return random.randrange(len(validQvals))
    
    def _getWeightedIndex(self, validQvals):
        """ chooses a move randomly based on predicted Q values """
        validQvals = self._makePositive(validQvals)
        sumValue = sum(validQvals)
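        # roulette-wheel selection: a random "arrow" lands in a pit's segment
        # with probability proportional to that pit's (shifted) Q value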
        arrow = random.random() * sumValue
        runningSum = 0
        for index, val in enumerate(validQvals):
            runningSum += val
            if runningSum >= arrow:
                return index
        return 0
    
    def _getExponentialIndex(self, validQvals):
        """ chooses a moove randomly based on the exponential of the Q values """
        validQvals = self._makePositive(validQvals)
        validQvals = self._getExponentialValues(validQvals)
        return self._getWeightedIndex(validQvals)
    
    def _getExponentialValues(self, arr):
        """ returns an array of the exponential of the values of the array """
        return [math.exp(val) for val in arr]
    
    def _makePositive(self, arr):
        """ if the array has a negative value, its absolute value is added to
        all elements of the array; half the smallest positive value is then
        added to all zero values """
        minVal = min(arr)
        if minVal < 0:
            arr = self._addToArray(-minVal, arr)
            minVal = self._getMinPos(arr)
            arr = self._addToZeros(minVal / 2.0, arr)
        return arr
        
    def _getMinPos(self, arr):
        """ finds the minimum positive value in the array; returns 0 if there is none """
        positives = [i for i in arr if i > 0]
        # the minimum positive was found
        if positives:
            return min(positives)
        # array has no positive values
        return 0
        
    def _addToZeros(self, num, arr):
        """ adds num to all zero values in the array """
        for index, val in enumerate(arr):
            if val == 0:
                arr[index] += num
        return arr
                
    def _addToArray(self, num, arr):
        """ adds the num to all values in the array """
        return [i + num for i in arr]
    
    def _getBestIndex(self, validQvals):
        """ chooses current expected best move """
        maxVal = max(validQvals) # FIXME
        bestMoves = [index for index, move in enumerate(validQvals) if move == maxVal]

        # heuristic: choose last bucket
        return int(bestMoves[-1])
    
    def _getQvals(self, board):
        """ retrieves the q values for all actions from the current state """
        state = self._getState(board)
        # create the input to the neural network: slot 0 holds the candidate action,
        # followed by the state features
        toNN = [0.0] + state[:self.inputSize - 1]
        # find expected rewards
        qVals = []
        for i in range(self.rowSize):
            toNN[0] = float(i)
            qVals.append(self.Q.calculate(toNN))
        return qVals
        
    def _getState(self, board):
        """ constructs the state as a list """
        mySide = board.mySide(self.id)
        oppSide = board.oppSide(self.id)
        myMancala = board.stonesInMyMancala(self.id)
        oppMancala = board.stonesInOppMancala(self.id)
        
        state = []  # size should be inputSize - 1
        state.append(float(myMancala))
        for my in mySide:
            state.append(float(my))
        state.append(float(oppMancala))
        for op in oppSide:
            state.append(float(op))
        return state
    
    def gameOver(self, myScore, oppScore):
        """ notifies learner that the game is over,
        update the Q function based on win or loss and the move list """
        if not self.learn:
            return
        
        reward = float(myScore) - float(oppScore)
        self.movelist[self.id].append(reward)
        self._updateGameRecord(self.movelist[self.id])
        self.movelist[self.id] = []
        self._learnFromGameRecord()
        
    def _learnFromGameRecord(self):
        for game in self.recentGames:
            self._learnFromGame(game)
            
    def _learnFromGame(self, movelist):
        """ replays a finished game backwards and applies the Q-learning update """
        reward = float(movelist[-1])
        # update Q function for the final move using the terminal reward
        sap = movelist[-2]
        state = sap.state
        action = sap.action

        example = [float(action)]
        example.extend(state[:self.inputSize - 1])

        oldQ = self.Q.calculate(example)
        newQ = float((1.0 - self.alpha) * oldQ + self.alpha * reward)
        self.Q.learnFromExample(example, newQ)

        # walk backwards through the remaining moves, bootstrapping from the successor state
        nextExample = example[:]

        for i in range(3, len(movelist) + 1):
            sap = movelist[len(movelist) - i]
            reward = getattr(sap, 'reward', 0.0)  # intermediate moves carry no explicit reward
            state = sap.state
            action = sap.action

            example = [float(action)]
            example.extend(state[:self.inputSize - 1])

            # find expected rewards of the successor state
            qVals = []
            for j in range(self.rowSize):
                nextExample[0] = float(j)
                qVals.append(self.Q.calculate(nextExample))

            maxVal = max(qVals)
            oldQ = self.Q.calculate(example)
            newQ = float((1.0 - self.alpha) * oldQ + self.alpha * (reward + self.discount * maxVal))
            self.Q.learnFromExample(example, newQ)
            nextExample = example[:]
            
    def _updateGameRecord(self, moves):
        """ stores the latest game transcript, keeping only the most recent games """
        while len(self.recentGames) >= self.numRecent:
            del self.recentGames[0]
        self.recentGames.append(moves)
        
    def setNumRecent(self, recent):
        """ changes number of results to store as recent """
        self.numRecent = recent
        
    def setNumIterations(self, iters):
        self.numIterations = iters
        
    def saveToFile(self, filename, mode):
        self.Q.saveToFile(filename, mode)
        f = open(filename, mode)
        f.write(str(self.id)+"\n")
        f.write(str(self.rowSize)+"\n")
        f.write(str(self.stones)+"\n")
        f.write(str(self.inputSize)+"\n")
        f.write(self.strategy+"\n")
        f.write(str(self.learn)+"\n")
        f.write(str(self.alpha)+"\n")
        f.write(str(self.discount)+"\n")
        f.write(str(self.numIterations)+"\n")
        f.write(str(self.numRecent)+"\n")
        f.flush()
        f.close()
        
    def loadFromFile(self, filename):
        self.Q.loadFromFile(filename)
        f = open(filename, 'r')
        self.id = int(f.readline())
        self.rowSize = int(f.readline())
        self.stones = int(f.readline())
        self.inputSize = int(f.readline())
        self.strategy = f.readline().strip()
        self.learn = f.readline().strip() == 'True'
        self.alpha = float(f.readline())
        self.discount = float(f.readline())
        self.numIterations = int(f.readline())
        self.numRecent = int(f.readline())
        f.close()
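
The class above is not self-contained: it relies on a Player base class, a Pair record and a NeuralNet implementation, plus a game Board, all provided elsewhere in the project. The sketch below shows one plausible way to drive it in a self-play training loop; the Board API (doMove, isOver, score) and the Pair definition are illustrative assumptions, not part of the original listing.

# Hypothetical Pair record with the fields the learner reads (state, action,
# and an optional per-move reward); the real project class may differ.
class Pair(object):
    def __init__(self, state, action, reward=0.0):
        self.state = state
        self.action = action
        self.reward = reward

# Hypothetical training driver; Board, doMove, isOver and score are assumed APIs.
def train(numGames=1000):
    player = NNPlayer(0)
    player.setStrategy('exponential')   # explore while learning
    player.setLearning(True)
    for _ in range(numGames):
        board = Board()                          # assumed project class
        while not board.isOver():
            move = player.getMove(board)
            if move < 0:                         # no legal move left
                break
            board.doMove(player.id, move)        # assumed API
        player.gameOver(board.score(0), board.score(1))   # final pit counts
    player.setStrategy('greedy')        # exploit the learned Q function afterwards
    return player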