示例#1
0
				optim_config={
				  'learning_rate': default_para['learning_rate'],
				  'momentum': default_para['momentum']
				},
				lr_decay=default_para['lr_decay'],
				num_epochs=default_para['num_epochs'], 
				batch_size=default_para['batch_size'] ,
				print_every=200,
				verbose = True)
solver.train()

best_val_acc = solver.best_val_acc


def _report_split(f1_caption, f1_value, matrix_caption, confusion_matrix):
	"""Print one split's F1 score, then its caption and rendered confusion matrix."""
	print(f1_caption, f1_value)
	print(matrix_caption)
	display_conf(confusion_matrix)


# Evaluate the validation and test splits up front; each call yields a
# (confusion matrix, F1) pair for the requested split.
valid_confu_mat, valid_F1 = solver.get_conf_mat_F1(type='validation')
test_confu_mat, test_F1 = solver.get_conf_mat_F1(type='test')

# Accuracy banner first, then the per-split F1 / confusion-matrix details.
print('\n')
print('**************The best validation data Accuracy is: ', best_val_acc, '***********************')
print('**************The test data Accuracy is: ', solver.get_test_accu(), '***********************')
_report_split('F1 value for Validation: ', valid_F1, 'Confusion Matrix for Validation: \n', valid_confu_mat)
_report_split('F1 value for Test: ', test_F1, 'Confusion Matrix for Test: \n', test_confu_mat)
print('\n\n\n')

# Plot graph: Loss VS. Iteration, Accuracy VS. Epoch
plt.subplot(2, 1, 1)
plt.title('Trainingloss')
						print_every=200,
						verbose = True)
		solver.train()

		best_val_acc = solver.best_val_acc

		#For 'learning_rate' and 'regularization', the x-axis is in log-space
		if (para_type == 'learning_rate' or para_type =='regularization') and value != 0:
			plot_para_value.append(np.log10(value))
		else:
			plot_para_value.append(value)

		plot_val_acc.append(best_val_acc)

		#Calculate the Confusion Matrix, F1 and print out them
		confu_mat, F1 = solver.get_conf_mat_F1(type='validation')
		print('\n')
		print('**************The best validation data Accuracy is: ', best_val_acc, '***********************')
		print('F1 value: ', F1)
		print('Confusion Matrix: \n')
		display_conf(confu_mat)
		print('\n\n\n')

		#Plot graph: Loss VS. Iteration, Accuracy VS. Epoch
		plt.subplot(2,1,1)
		plt.title('Trainingloss')
		plt.plot(solver.loss_history,'o')
		plt.xlabel('Iteration')
		plt.subplot(2,1,2)
		plt.title('Accuracy')
		plt.plot(solver.train_acc_history,'-o',label='train')