def T8():
    '''
    Tests if multiple ANNRs can be created without affecting each other.

    A first regressor is trained and a snapshot of its weight matrix 0 is
    taken. A second regressor with the same architecture is then built and
    trained on the same data. The test passes only if the FIRST model's
    weight matrix 0 is still element-wise equal to the snapshot.

    Returns True on success, False otherwise.
    '''
    def build():
        # A fresh layer-spec list per model, as in two separate literals.
        return ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter=16)

    X = np.random.rand(32, 4)
    target = (X.sum(axis=1) ** 2).reshape(-1, 1)

    first = build()
    first.fit(X, target)
    snapshot = first.GetWeightMatrix(0)

    second = build()
    second.fit(X, target)
    # Intentionally re-read the FIRST model: training the second must not
    # have changed it.
    current = first.GetWeightMatrix(0)

    return not (snapshot != current).any()
def T14():
    '''
    Tests saving a trained model and restoring it into a new model.

    A regressor named 't12ann1' is trained and saved via SaveModel('./t12ann1').
    Its weight matrix 0 is recorded. After ANN.Reset(), a second regressor
    named 't12ann2' restores from ('./', 't12ann1'). The test passes only if
    its weight matrix 0 is element-wise equal to the recorded one.

    Returns True on success, False otherwise.
    NOTE(review): writes model files under './' and never removes them.
    '''
    X = np.random.rand(32, 4)
    target = (X.sum(axis=1) ** 2).reshape(-1, 1)

    net = ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter=16, name='t12ann1')
    net.fit(X, target)
    net.SaveModel('./t12ann1')
    saved_weights = net.GetWeightMatrix(0)

    # NOTE(review): ANN.Reset() presumably clears shared framework state so the
    # restored model cannot simply see the trained variables -- confirm.
    ANN.Reset()

    # Rebinding `net` releases the first model at the same point as before.
    net = ANNR([4], [('F', 4), ('AF', 'tanh'), ('F', 1)], maxIter=16, name='t12ann2')
    net.RestoreModel('./', 't12ann1')
    restored_weights = net.GetWeightMatrix(0)

    return not (saved_weights != restored_weights).any()