# Example #1
def T8():
    '''
    Tests if multiple ANNRs can be created without affecting each other.

    Fits one model (m1) and snapshots its first weight matrix. Then it
    builds and fits a second model (m2) with the same architecture on the
    same data. Finally it re-reads m1's first weight matrix and checks that
    it did not change.

    Returns:
        True if m1's first weight matrix is identical before and after
        creating/fitting m2, False otherwise.
    '''
    A = np.random.rand(32, 4)
    Y = (A.sum(axis = 1) ** 2).reshape(-1, 1)
    m1 = ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter = 16)
    m1.fit(A, Y)
    # Take an explicit copy so the "before" snapshot cannot alias storage
    # that later changes. NOTE(review): whether GetWeightMatrix returns a
    # fresh array or a live reference is not visible here; the copy makes
    # the test correct either way.
    R1 = np.array(m1.GetWeightMatrix(0), copy = True)
    m2 = ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter = 16)
    m2.fit(A, Y)
    # Intentionally re-read m1 (not m2): the property under test is that
    # m1's weights are unaffected by the existence/training of m2.
    R2 = m1.GetWeightMatrix(0)
    # array_equal returns False (rather than raising) on a shape mismatch.
    return np.array_equal(R1, R2)
# Example #2
def T14():
    '''
    Tests saving and restoring a model.

    The test runs these steps:
      1. Fits a model and saves it to './t12ann1'.
      2. Snapshots its first weight matrix.
      3. Calls ANN.Reset().
      4. Constructs a new ANNR with the same architecture and restores the
         saved model 't12ann1' from './' into it.
      5. Compares the restored first weight matrix with the snapshot.

    Returns:
        True if the restored first weight matrix equals the saved one,
        False otherwise.
    '''
    A = np.random.rand(32, 4)
    Y = (A.sum(axis = 1) ** 2).reshape(-1, 1)
    # NOTE(review): the 't12ann*' names look copied from another test; they
    # determine the files written on disk, so they are left unchanged here.
    m1 = ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter = 16, name = 't12ann1')
    m1.fit(A, Y)
    m1.SaveModel('./t12ann1')
    # Copy the snapshot so it cannot be affected by ANN.Reset() below, in
    # case GetWeightMatrix returns a live reference (not visible from here).
    R1 = np.array(m1.GetWeightMatrix(0), copy = True)
    # Presumably clears shared model state so the restore starts clean --
    # confirm against ANN.Reset's implementation.
    ANN.Reset()
    # Fresh model with the same layer spec under a different name; its
    # parameters come only from the saved checkpoint.
    m1 = ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter = 16, name = 't12ann2')
    m1.RestoreModel('./', 't12ann1')
    R2 = m1.GetWeightMatrix(0)
    # array_equal returns False (rather than raising) on a shape mismatch.
    return np.array_equal(R1, R2)