def evaluateModelAgainstTrainingData(model):
    """Evaluate ``model`` on all stored training data and print its losses.

    The data comes from ``MemoryBuffers.getAllTrainingData()``, which is
    unpacked here into input states, value labels and policy labels. Each of
    them is converted to a NumPy array before being handed to
    ``model.evaluate``; the two label arrays are passed together as a list.

    Args:
        model: Object exposing ``evaluate(x, y, verbose=..., shuffle=...)``
            that returns three values. The first one is discarded; the second
            and third are reported as the value loss and the policy loss.
            Presumably a two-headed Keras model -- TODO confirm.

    Returns:
        None. The losses are only printed.
    """
    states, values, policies = MemoryBuffers.getAllTrainingData()

    features = np.array(states)
    targets = [np.array(values), np.array(policies)]

    # NOTE(review): if ``model`` is a Keras model, ``evaluate`` does not
    # document a ``shuffle`` argument -- confirm this call is accepted.
    losses = model.evaluate(features, targets, verbose=2, shuffle=True)

    # The first returned value is ignored; only the two per-head losses are shown.
    _, valueLoss, policyLoss = losses
    report = "ValueLoss: {}  PolicyLoss: {}".format(valueLoss, policyLoss)
    print(report)
# Example #2
# 0
def sendToOverlord(overlordConnection, localPipe, amountOfWorkers, endPipe):
    """Relay data produced by local workers to the overlord until shutdown.

    After a short start-up delay and ``StartInit.init()``, the function loops
    over tuple messages read from ``localPipe``. The first element of each
    message is its type:

    * ``C.LocalWorkerProtocol.DUMP_TO_REPLAY_BUFFER`` -- the states, evals and
      policies are added to the replay buffer. Once at least
      ``MachineSpecificSettings.GAMES_BATCH_SIZE_TO_OVERLORD`` games have been
      collected, everything in the buffer is sent to the overlord, the game
      counter is reset and the buffer is cleared.
    * ``C.LocalWorkerProtocol.DUMP_MOST_VISITED_STATES`` -- the states of the
      message are accumulated. After ``amountOfWorkers`` such messages the
      accumulated states are sent to the overlord and the loop ends.

    Messages of any other type are ignored.

    Args:
        overlordConnection: Object with a ``sendMessage(type, payload)``
            method; also handed to ``sendMostVisitedStatesToOverlord``.
        localPipe: Object whose ``get()`` returns the next message tuple
            (presumably a queue fed by the local workers -- TODO confirm).
        amountOfWorkers: Number of ``DUMP_MOST_VISITED_STATES`` messages to
            wait for before finishing.
        endPipe: Object with a ``put()`` method; receives the string
            ``"Ending by datamanager"`` when the loop has finished.

    Returns:
        None.
    """
    # NOTE(review): the reason for the 3 second delay is not visible here.
    # The original comment read "Needed in the end when we wish to count the
    # bitmaps", which presumably refers to StartInit.init() -- confirm.
    import time
    time.sleep(3)
    print("Starting init")
    import StartInit
    StartInit.init()

    amountOfCollectedGames = 0
    amountOfCollectedWorkers = 0
    collectedVisitedStates = []

    while True:
        tupleMsg = localPipe.get()
        msgType = tupleMsg[0]

        if msgType == C.LocalWorkerProtocol.DUMP_TO_REPLAY_BUFFER:
            # NOTE(review): the weights in the message are not passed on to
            # the replay buffer -- confirm that this is intended.
            _, amountOfGames, states, evals, polices, _weights = tupleMsg
            MemoryBuffers.addLabelsToReplayBuffer(states, evals, polices)
            amountOfCollectedGames += amountOfGames

            if (amountOfCollectedGames >=
                    MachineSpecificSettings.GAMES_BATCH_SIZE_TO_OVERLORD):
                print("Sending to oracle from dataworker")
                dStates, dEvals, dPolices, dWeights = (
                    MemoryBuffers.getAllTrainingData())
                dumpMsg = (amountOfCollectedGames, dStates, dEvals, dPolices,
                           dWeights)
                overlordConnection.sendMessage(
                    C.RemoteProtocol.DUMP_REPLAY_DATA_TO_OVERLORD, dumpMsg)

                # Start a fresh batch once the current one has been shipped.
                amountOfCollectedGames = 0
                MemoryBuffers.clearReplayBuffer()

        elif msgType == C.LocalWorkerProtocol.DUMP_MOST_VISITED_STATES:
            amountOfCollectedWorkers += 1
            _, states = tupleMsg
            # Keep what each worker reported; previously the states were
            # unpacked and dropped, so an empty list was always forwarded.
            # Assumes ``states`` is an iterable of states -- TODO confirm.
            collectedVisitedStates.extend(states)

            if amountOfCollectedWorkers >= amountOfWorkers:
                print("collected states from all local workers: ",
                      len(collectedVisitedStates))
                sendMostVisitedStatesToOverlord(overlordConnection,
                                                collectedVisitedStates)
                print("Sent message to all workers")
                break

    endPipe.put("Ending by datamanager")
    print("Ending sending thread")