# Parse the CSV rows: column 0 is the class character (one-hot encoded via
# charToOneHot), the remaining columns are integer features.
data = []
for line in lines:
    # Skip blank lines (e.g. a trailing newline at end of file); they would
    # otherwise reach charToOneHot with an empty/newline label.
    if not line.strip():
        continue
    values = line.split(',')
    # Divide by 15 for min-max normalisation (assumes features lie in 0..15 —
    # TODO confirm against the dataset). int() already ignores surrounding
    # whitespace such as the trailing newline on the last column.
    features = [int(x) / 15 for x in values[1:]]
    label = charToOneHot(values[0])
    data.append([features, label])

# NOTE(review): 0.8 is presumably the train fraction of the train/test split —
# confirm against DataSet.
dataSet = DataSet(data, 0.8)

# Load or instantiate a new MLP.
#   no arguments -> build a fresh network and train it
#   'load'       -> MLP(loadModel=True), one pass without training
if len(sys.argv) == 1:
    mlp = MLP(dataSet.numInputs, numHidden, dataSet.numOutputs, outputActivation='SOFTMAX', hiddenActivation=hiddenActivation)
    train = True
elif sys.argv[1] == 'load':
    mlp = MLP(loadModel=True)
    train = False
    numEpochs = 1
else:
    # Any other argument would leave `mlp`/`train` undefined and crash in the
    # loop below with a NameError; fail fast with a usage message instead.
    print("Usage: " + sys.argv[0] + " [load]")
    sys.exit(1)

# Train and test. Test metrics are only computed on reporting epochs, and the
# final epoch is always reported.
for epoch in range(numEpochs):
    trainLoss, trainAccuracy = mlp.process(dataSet.trainData, train=train, learningRate=learningRate, updateFreq=updateFreq)
    if epoch % printFreq == 0 or epoch == numEpochs - 1:
        testLoss, testAccuracy = mlp.process(dataSet.testData, train=False)
        print("EPOCH:      ", epoch)
        print("TRAIN LOSS: ", trainLoss)
        print("TRAIN ACC:  ", trainAccuracy, "%")
        print("TEST LOSS:  ", testLoss)
        print("TEST ACC:   ", testAccuracy, "%\n")
# Example No. 2 (scraped listing header "Exemplo n.º 2", score 0)
                    [[1, 1], [1, 0]]])

# Pick the data set and learning rate for the requested output activation.
# Best learning rates for each were determined through experimenting.
# An explicit if/elif/else chain replaces probing locals()/globals() for the
# name 'dataSet', which would wrongly succeed if that name already existed.
if outputActivation == 'SIGMOID':
    dataSet = dataSet1
    learningRate = 1
elif outputActivation == 'TANH':
    dataSet = dataSet2
    learningRate = 1
elif outputActivation == 'RELU':
    dataSet = dataSet1
    learningRate = 0.01
elif outputActivation == 'SOFTMAX':
    dataSet = dataSet3
    learningRate = 0.1
else:
    print("\'" + outputActivation + "\' is not a valid activation function.")
    # Nonzero status so callers/shells can detect the invalid configuration.
    sys.exit(1)

# NOTE(review): the positional 2 is presumably the hidden-layer size —
# confirm against MLP's constructor signature.
mlp = MLP(dataSet.numInputs, 2, dataSet.numOutputs, outputActivation,
          hiddenActivation)

for epoch in range(numEpochs):
    epochLoss, epochAccuracy = mlp.process(dataSet.trainData,
                                           train=True,
                                           learningRate=learningRate)
    # Report every printFreq epochs and always the final one, so the
    # end-of-training result is shown even when numEpochs % printFreq != 1.
    if epoch % printFreq == 0 or epoch == numEpochs - 1:
        print("EPOCH:      ", epoch)
        print("TRAIN LOSS: ", epochLoss)
        print("TRAIN ACC:  ", epochAccuracy, "%\n")