class NeuralNetworkTest(unittest.TestCase):
	"""Train NeuralNetwork on small boolean truth tables and check the error.

	Each test seeds numpy's random generator with ``self.seed``, trains the
	network created in setUp on one truth table and requires the total error
	returned by retrieveEstimationError to be at most ``acceptanceEpsilon``.
	"""

	# Input rows shared by several tests. The order is kept exactly as
	# written so that every target row lines up with its input row.
	_TWO_INPUT_ROWS = (
		(0, 0),
		(0, 1),
		(1, 0),
		(1, 1),
	)
	_THREE_INPUT_ROWS = (
		(0, 0, 0),
		(0, 0, 1),
		(0, 1, 0),
		(1, 0, 0),
		(0, 1, 1),
		(1, 1, 0),
		(1, 0, 1),
		(1, 1, 1),
	)

	def setUp(self):
		"""Create a fresh network and the shared test parameters."""
		self.neuralNetwork = NeuralNetwork(learning_rate=0.15,n_hidden=2,momentum=0.95,activation='tanh')
		# Largest total error for which a test still passes.
		self.acceptanceEpsilon = 0.05
		# Seed handed to numpy.random.seed at the start of every test.
		self.seed = 1
		# Passed to backpropagation as maxIterations.
		self.maxIterations = 11000

	def tearDown(self):
		"""Drop the attributes created by setUp."""
		del self.neuralNetwork
		del self.acceptanceEpsilon
		del self.seed
		del self.maxIterations

	def retrieveEstimationError(self,x,target):
		"""Train the network on (x, target) and return its total error.

		x      -- 2-D array of inputs, one sample per row.
		target -- 2-D array of expected outputs, one sample per row.

		The column counts of x and target are written to the network as n_in
		and n_out, the weights are initialised, backpropagation is run with
		self.maxIterations, and the output of feed_forward(x) is compared with
		target through EstimationError. The value of getTotalError() is
		returned.
		"""
		# Size the input and output layers from the data.
		_ , inputColumns = x.shape
		_ , outputColumns = target.shape
		self.neuralNetwork.n_in = inputColumns
		self.neuralNetwork.n_out = outputColumns

		self.neuralNetwork.initialize_weights()
		self.neuralNetwork.backpropagation(x,target,maxIterations=self.maxIterations)

		# Network result after training.
		estimation = self.neuralNetwork.feed_forward(x)

		estimationError = EstimationError(estimatedValues=estimation,targetValues=target)
		estimationError.computeErrors()
		return estimationError.getTotalError()

	def _assertLearns(self,label,inputRows,targetRows):
		"""Seed numpy, train on one truth table and check the total error.

		label      -- name of the boolean function, used in the printed line.
		inputRows  -- sequence of input rows.
		targetRows -- sequence of one-element target rows, in the same order.
		"""
		# Seed first, so the run is repeatable for a given self.seed.
		numpy.random.seed(seed=self.seed)
		x = numpy.array(inputRows)
		target = numpy.array(targetRows)

		totalError = self.retrieveEstimationError(x,target)
		# Single-argument print: same output under Python 2 and Python 3.
		print('Error %s: %s' % (label, totalError))
		# assertLessEqual reports both values when the check fails.
		self.assertLessEqual(totalError,self.acceptanceEpsilon)

	#test x1 xor x2
	def testXOR(self):
		self._assertLearns('XOR',self._TWO_INPUT_ROWS,[[0],[1],[1],[0]])

	#test x1 or x2
	def testOR(self):
		self._assertLearns('OR',self._TWO_INPUT_ROWS,[[0],[1],[1],[1]])

	#test x1 and x2
	def testAND(self):
		self._assertLearns('AND',self._TWO_INPUT_ROWS,[[0],[0],[0],[1]])

	#test (x1 or x2) and x3
	def testORAND(self):
		self._assertLearns('ORAND',self._THREE_INPUT_ROWS,[[0],[0],[0],[0],[1],[0],[1],[1]])

	#test (x1 and x2) or x3
	def testANDOR(self):
		self._assertLearns('ANDOR',self._THREE_INPUT_ROWS,[[0],[1],[0],[0],[1],[1],[1],[1]])
		# Read every line of the dataset and turn it into one training sample.
		# NOTE(review): `dataset`, `x_input` and `target` are defined outside
		# the visible part of the file.
		data = dataset.readlines()
		for entry in data:
			# Each entry is split on whitespace into exactly three fields.
			(x,y,area) = entry.split()
			x = float(x)
			y = float(y)
			area = int(area)
			# An area label of -1 is stored as 0.
			if area==-1: area = 0
			# Inputs are [x, y] pairs; targets are one-element rows.
			x_input.append([x,y])
			target.append([area])

	# Convert the collected samples to arrays; both are unpacked as 2-D below.
	x_input = numpy.array(x_input)
	target = numpy.array(target)

	numpy.random.seed(seed=1) #using fixed seed for testing purposes
	# Size the input and output layers from the data.
	_ , xColumns = x_input.shape
	_ , targetColumns = target.shape
	n_hidden = 8
	momentum = 0
	neuralNetwork = NeuralNetwork(learning_rate=0.01,n_in=xColumns,n_hidden=n_hidden,n_out=targetColumns, momentum = momentum)

	neuralNetwork.initialize_weights()
	# File names encode learning rate, momentum, hidden-unit count and the
	# dataset file name without its extension.
	results_file = ''.join(['results_lr',str(neuralNetwork.learning_rate),'_m',str(momentum),'_',str(n_hidden),"hidden",file_name.rsplit('.', 1)[0]]+['.out'])

	neuralNetwork.backpropagation(x_input,target,maxIterations=10000, batch= False,file_name=results_file)
	# Built after training, at the same point as before, so it still reads
	# learning_rate as it is once backpropagation has returned.
	network_file = ''.join(['trained_lr',str(neuralNetwork.learning_rate),'_m',str(momentum),'_',str(n_hidden),"hidden",file_name.rsplit('.', 1)[0]]+['.nn'])
	# Close the file deterministically so the pickle is fully written before
	# it is read back below.
	with open(network_file,'wb') as network_out:
		pickle.dump(neuralNetwork, network_out)

	# Single-argument print: same output under Python 2 and Python 3.
	print(neuralNetwork.feed_forward(x_input))
	# Reload the network that was just saved and print its output for the
	# same inputs, next to the output printed above.
	with open(network_file,'rb') as network_in:
		nn2 = pickle.load(network_in)
	print(network_file)
	print(nn2.feed_forward(x_input))