示例#1
0
def train_samm_cross(batch_size, spatial_epochs, temporal_epochs, train_id, dB, spatial_size, flag, tensorboard):
    """Fit the spatial VGG network and the temporal module on a whole dataset.

    Every sample returned by ``standard_data_loader`` is used for training;
    this routine holds nothing out and computes no test metrics.  Fitting
    only happens for modes that enable both the spatial and the temporal
    stage ('st', 'scratch', 'st4', 'st7').  For any other ``flag`` the data
    and models are still built, but the function returns without training.

    Args:
        batch_size: Batch size passed to every ``fit``/``predict`` call.
        spatial_epochs: Number of epochs for the VGG fit.
        temporal_epochs: Number of epochs for the temporal-module fit.
        train_id: Identifier embedded in weight, result and log names.
        dB: Dataset name: 'CASME2_TIM', 'CASME2_Optical', 'SAMM_TIM10' or
            'SAMM_CASME_Optical'.
        spatial_size: Side length used for both image dimensions.
        flag: Training mode: 'st', 's', 't', 'nofine', 'scratch', 'st4'
            or 'st7'.  Unknown values leave every stage disabled.
        tensorboard: 1 to log both fits through TensorBoard callbacks.

    Raises:
        ValueError: If ``dB`` is not one of the supported dataset names.
    """
    ############## Path Preparation ######################
    root_db_path = "/media/ice/OS/Datasets/"
    workplace = root_db_path + dB + "/"
    inputDir = root_db_path + dB + "/" + dB + "/"
    # Same layout train_vgg_lstm uses for its TensorBoard logs.
    tensorboard_path = root_db_path + "tensorboard/"
    result_dir = workplace + 'Classification/' + 'Result/' + dB + '/'
    ######################################################

    ############## Dataset configuration #################
    classes = 5
    # Only the cross-database branch builds a filtered sample list.
    # NOTE(review): the other datasets now pass an empty list instead of
    # hitting an unbound name -- confirm Read_Input_Images_SAMM_CASME
    # accepts that.
    list_samples = []

    if dB == 'CASME2_TIM':
        table = loading_casme_table(workplace + 'CASME2-ObjectiveClasses.xlsx')
        # The returned ignore list is overridden with an empty one below and
        # the video counts are hard-coded, so only the call itself is kept.
        ignore_casme_samples(inputDir)
        listOfIgnoredSamples = []

        r = w = spatial_size
        subjects = 2
        VidPerSubject = [2, 1]
        timesteps_TIM = 10
        channel = 3

        # Guarded like the SAMM_CASME_Optical branch so a first run (no
        # stale label file yet) does not abort.
        label_file = workplace + "Classification/CASME2_TIM_label.txt"
        if os.path.isfile(label_file):
            os.remove(label_file)

    elif dB == 'CASME2_Optical':
        table = loading_casme_table(workplace + 'CASME2-ObjectiveClasses.xlsx')
        listOfIgnoredSamples, IgnoredSamples_index = ignore_casme_samples(inputDir)

        r = w = spatial_size
        subjects = 26
        VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        timesteps_TIM = 9
        channel = 3

    elif dB == 'SAMM_TIM10':
        table, _ = loading_samm_table(root_db_path, dB)
        listOfIgnoredSamples = []
        IgnoredSamples_index = np.empty([0])

        r = w = spatial_size
        subjects = 29
        VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        timesteps_TIM = 10
        channel = 3
        classes = 8

    elif dB == 'SAMM_CASME_Optical':
        table = loading_casme_objective_table(root_db_path, dB)

        listOfIgnoredSamples = []
        IgnoredSamples_index = np.empty([0])
        sub_items = np.empty([0])
        list_samples = filter_objective_samples(table)

        r = w = spatial_size
        subjects = 26
        # list_samples: samples with a wanted objective class.
        VidPerSubject, list_samples = get_subfolders_num_crossdb(
            inputDir, IgnoredSamples_index, sub_items, table, list_samples)

        timesteps_TIM = 9
        channel = 3
        classes = 5

        label_file = workplace + "Classification/SAMM_CASME_Optical_label.txt"
        if os.path.isfile(label_file):
            os.remove(label_file)

    else:
        # Without this, an unknown name surfaced much later as an unbound
        # local variable.
        raise ValueError("Unsupported dB: " + str(dB))
    ######################################################

    ############## Flags ####################
    tensorboard_flag = tensorboard
    resizedFlag = 1

    # flag -> (train_spatial, train_temporal, finetuning, channel_flag)
    mode_table = {
        'st': (1, 1, 1, 0),
        's': (1, 0, 1, 0),
        't': (0, 1, 0, 0),
        'nofine': (0, 0, 0, 0),
        'scratch': (1, 1, 0, 0),
        'st4': (1, 1, 0, 1),
        'st7': (1, 1, 0, 2),
    }
    (train_spatial_flag, train_temporal_flag,
     finetuning_flag, channel_flag) = mode_table.get(flag, (0, 0, 0, 0))
    #########################################

    ############ Reading Images and Labels ################
    SubperdB = Read_Input_Images_SAMM_CASME(
        inputDir, list_samples, listOfIgnoredSamples, dB, resizedFlag,
        table, workplace, spatial_size, channel)
    print("Loaded Images into the tray...")
    labelperSub = label_matching(workplace, dB, subjects, VidPerSubject)
    print("Loaded Labels into the tray...")

    if channel_flag == 1:
        SubperdB_strain = Read_Input_Images(
            inputDir, listOfIgnoredSamples, 'CASME2_Strain_TIM10',
            resizedFlag, table, workplace, spatial_size, 1)
    elif channel_flag == 2:
        SubperdB_strain = Read_Input_Images(
            inputDir, listOfIgnoredSamples, 'CASME2_Strain_TIM10',
            resizedFlag, table, workplace, spatial_size, 1)
        SubperdB_gray = Read_Input_Images(
            inputDir, listOfIgnoredSamples, 'CASME2_TIM',
            resizedFlag, table, workplace, spatial_size, 3)
    #######################################################

    ########### Model Configurations #######################
    adam = optimizers.Adam(lr=0.00001, decay=0.000001)

    # Temporal-only learning consumes raw pixels; every other mode feeds the
    # temporal module with features taken from the spatial network.
    temporal_only = train_spatial_flag == 0 and train_temporal_flag == 1
    if temporal_only and dB != 'CASME2_Optical':
        data_dim = spatial_size * spatial_size
    elif temporal_only:
        data_dim = spatial_size * spatial_size * 3
    else:
        data_dim = 4096

    spatial_weights_name = 'vgg_spatial_' + str(train_id) + '_casme2_'
    temporal_weights_name = 'temporal_ID_' + str(train_id) + '_casme2_'
    history = LossHistory()
    stopping = EarlyStopping(monitor='loss', min_delta=0, mode='min')
    ########################################################

    ############### Model construction ########################
    vgg_model_cam = VGG_16(spatial_size=spatial_size, classes=classes, weights_path='VGG_Face_Deep_16.h5')

    temporal_model = temporal_module(data_dim=data_dim, classes=classes, timesteps_TIM=timesteps_TIM)
    temporal_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=[metrics.categorical_accuracy])

    conv_ae = convolutional_autoencoder(spatial_size=spatial_size, classes=classes)
    conv_ae.compile(loss='binary_crossentropy', optimizer=adam)

    if channel_flag == 1 or channel_flag == 2:
        vgg_model = VGG_16_4_channels(classes=classes, spatial_size=spatial_size)
    else:
        vgg_model = VGG_16(spatial_size=spatial_size, classes=classes, weights_path='VGG_Face_Deep_16.h5')
    vgg_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=[metrics.categorical_accuracy])
    ###########################################################

    ############ for tensorboard ###############
    if tensorboard_flag == 1:
        # There is no per-subject loop in this routine, so the run is keyed
        # by train_id.  The directories are created only when missing so a
        # re-run with the same id does not fail.
        cat_path = tensorboard_path + str(train_id) + "/"
        if not os.path.isdir(cat_path):
            os.makedirs(cat_path)
        tbCallBack = keras.callbacks.TensorBoard(log_dir=cat_path, write_graph=True)

        cat_path2 = tensorboard_path + str(train_id) + "spat/"
        if not os.path.isdir(cat_path2):
            os.makedirs(cat_path2)
        tbCallBack2 = keras.callbacks.TensorBoard(log_dir=cat_path2, write_graph=True)
    #############################################

    ############ Data preparation ###############
    Train_X, Train_Y = standard_data_loader(SubperdB, labelperSub, subjects, classes)

    # Break every sequence into its individual frames.
    Train_X_spatial = Train_X.reshape(Train_X.shape[0] * timesteps_TIM, r, w, channel)

    # Extra input channels are loaded with the same whole-dataset loader as
    # the main stream; this routine has no held-out subject and therefore no
    # test tensors to extend.
    if channel_flag == 1:
        Train_X_Strain, _ = standard_data_loader(SubperdB_strain, labelperSub, subjects, classes)
        Train_X_Strain = Train_X_Strain.reshape(Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)

        Train_X_spatial = np.concatenate((Train_X_spatial, Train_X_Strain), axis=3)
        channel = 4

    elif channel_flag == 2:
        Train_X_Strain, _ = standard_data_loader(SubperdB_strain, labelperSub, subjects, classes)
        Train_X_gray, _ = standard_data_loader(SubperdB_gray, labelperSub, subjects, classes)
        Train_X_Strain = Train_X_Strain.reshape(Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
        Train_X_gray = Train_X_gray.reshape(Train_X_gray.shape[0] * timesteps_TIM, r, w, 3)

        Train_X_spatial = np.concatenate((Train_X_spatial, Train_X_Strain, Train_X_gray), axis=3)
        channel = 7

    if channel == 1:
        # Duplicate the channel of single-channel input images.
        Train_X_spatial = duplicate_channel(Train_X_spatial)

    # Repeat each sequence label once per frame so every image has a label.
    Train_Y_spatial = np.repeat(Train_Y, timesteps_TIM, axis=0)

    # NOTE(review): this is a reshape, not an axis transpose, so it does not
    # move the channel axis of the (N, r, w, channel) array to the front.
    # Kept as-is because previously saved weights were produced this way.
    X = Train_X_spatial.reshape(Train_X_spatial.shape[0], channel, r, w)
    y = Train_Y_spatial.reshape(Train_Y_spatial.shape[0], classes)
    normalized_X = X.astype('float32') / 255.

    print(X.shape)
    #############################################

    ###### conv weights must be frozen for transfer learning ######
    if finetuning_flag == 1:
        for layer in vgg_model.layers[:33]:
            layer.trainable = False
        for layer in vgg_model_cam.layers[:31]:
            layer.trainable = False

    if not (train_spatial_flag == 1 and train_temporal_flag == 1):
        return

    ##################### Spatial training #########################
    if tensorboard_flag == 1:
        spatial_callbacks = [tbCallBack2]
    else:
        spatial_callbacks = [history, stopping]
    vgg_model.fit(X, y, batch_size=batch_size, epochs=spatial_epochs, shuffle=True, callbacks=spatial_callbacks)

    # Record loss and accuracy; the context managers close the files even if
    # a write fails.
    with open(result_dir + 'loss_' + str(train_id) + '.txt', 'a') as file_loss:
        file_loss.write(str(history.losses) + "\n")
    with open(result_dir + 'accuracy_' + str(train_id) + '.txt', 'a') as file_acc:
        file_acc.write(str(history.accuracy) + "\n")

    vgg_model.save_weights(spatial_weights_name + 'HDE' + ".h5")
    model = Model(inputs=vgg_model.input, outputs=vgg_model.layers[35].output)
    plot_model(model, to_file="spatial_module_FULL_TRAINING.png", show_shapes=True)

    model_ae = Model(inputs=conv_ae.input, outputs=conv_ae.output)
    plot_model(model_ae, to_file='autoencoders.png', show_shapes=True)

    ##################### Autoencoding #########################
    output_ae = model_ae.predict(normalized_X, batch_size=batch_size)

    # Dump at most one batch worth of reconstructions; never index past the
    # number of predictions that actually exist.
    for i in range(min(batch_size, len(output_ae))):
        # NOTE(review): 224 is hard-coded here although the rest of the
        # function uses spatial_size -- confirm the auto-encoder output size.
        visual_ae = output_ae[i].reshape(224, 224, channel)
        # De-normalize to [0, 255].  The ndarray reductions are required:
        # the builtin min()/max() cannot reduce a multi-dimensional array.
        lowest = visual_ae.min()
        span = visual_ae.max() - lowest
        if span > 0:
            visual_ae = ((visual_ae - lowest) / span) * 255
        else:
            # A constant image has no range to stretch.
            visual_ae = np.zeros_like(visual_ae)
        fname = '{prefix}_{index}_{hash}.{format}'.format(
            prefix='AE_output', index=str(i),
            hash=np.random.randint(int(1e7)), format='png')
        cv2.imwrite(workplace + 'Classification/Result/ae_train/' + fname, visual_ae)

    output_ae = model.predict(output_ae, batch_size=batch_size)

    ##################### Spatial encoding #########################
    output = model.predict(X, batch_size=batch_size)

    # Merge auto-encoded features and spatial features.
    # NOTE(review): the concatenation doubles the per-frame feature width;
    # confirm it matches the data_dim the temporal module was built with.
    output = np.concatenate((output, output_ae), axis=1)
    features = output.reshape(int(Train_X.shape[0]), timesteps_TIM, output.shape[1])

    ##################### Temporal training #########################
    temporal_callbacks = [tbCallBack] if tensorboard_flag == 1 else None
    temporal_model.fit(features, Train_Y, batch_size=batch_size, epochs=temporal_epochs, callbacks=temporal_callbacks)

    temporal_model.save_weights(temporal_weights_name + 'HDE' + ".h5")
示例#2
0
def test_samm(batch_size, spatial_epochs, temporal_epochs, train_id, dB,
              spatial_size, flag, tensorboard):
    """Score saved spatial/temporal weights subject by subject.

    For every subject index the function rebuilds both networks, loads the
    per-subject temporal weights, gets the split for that subject from
    ``data_loader_with_LOSO``, predicts the test sequences and adds the
    resulting confusion matrix to a running total.  The cumulative F1 score
    is appended to the result file after each subject.

    Args:
        batch_size: Batch size passed to every ``predict`` call.
        spatial_epochs: Unused; kept for signature compatibility.
        temporal_epochs: Unused; kept for signature compatibility.
        train_id: Identifier embedded in weight and result file names.
        dB: Dataset name: 'CASME2_TIM', 'CASME2_Optical', 'SAMM_TIM10' or
            'SAMM_CASME_Optical'.
        spatial_size: Side length used for both image dimensions.
        flag: Evaluation mode; must enable both the spatial and temporal
            stage: 'st', 'scratch', 'st4' or 'st7'.
        tensorboard: Unused; kept for signature compatibility.

    Raises:
        ValueError: If ``flag`` does not enable both stages, or if ``dB``
            is not one of the supported dataset names.
    """
    ############## Flags ####################
    # Predictions only exist for modes that run both stages.  Reject the
    # rest up front, before any label file is removed or any data is loaded,
    # instead of failing on an unbound name inside the subject loop.
    if flag not in ('st', 'scratch', 'st4', 'st7'):
        raise ValueError(
            "test_samm needs a spatial+temporal flag "
            "('st', 'scratch', 'st4' or 'st7'), got: " + str(flag))
    finetuning_flag = 1 if flag == 'st' else 0
    channel_flag = {'st4': 1, 'st7': 2}.get(flag, 0)
    resizedFlag = 1
    #########################################

    ############## Path Preparation ######################
    root_db_path = "/media/viprlab/01D31FFEF66D5170/"
    workplace = root_db_path + dB + "/"
    inputDir = root_db_path + dB + "/" + dB + "/"
    result_dir = workplace + 'Classification/' + 'Result/' + dB + '/'
    weights_dir = root_db_path + 'Weights/' + str(train_id)
    ######################################################

    ############## Dataset configuration #################
    classes = 5
    # Only the cross-database branch builds a filtered sample list.
    # NOTE(review): the other datasets now pass an empty list instead of
    # hitting an unbound name -- confirm Read_Input_Images_SAMM_CASME
    # accepts that.
    list_samples = []

    if dB == 'CASME2_TIM':
        table = loading_casme_table(workplace + 'CASME2-ObjectiveClasses.xlsx')
        # The returned ignore list is overridden with an empty one below and
        # the video counts are hard-coded, so only the call itself is kept.
        ignore_casme_samples(inputDir)
        listOfIgnoredSamples = []

        r = w = spatial_size
        subjects = 2
        n_exp = 5
        VidPerSubject = [2, 1]
        timesteps_TIM = 10
        channel = 3

        # Guarded like the SAMM_CASME_Optical branch so a missing label file
        # does not abort the run.
        label_file = workplace + "Classification/CASME2_TIM_label.txt"
        if os.path.isfile(label_file):
            os.remove(label_file)

    elif dB == 'CASME2_Optical':
        table = loading_casme_table(workplace + 'CASME2-ObjectiveClasses.xlsx')
        listOfIgnoredSamples, IgnoredSamples_index = ignore_casme_samples(
            inputDir)

        r = w = spatial_size
        subjects = 26
        n_exp = 5
        VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        timesteps_TIM = 9
        channel = 3

    elif dB == 'SAMM_TIM10':
        table, _ = loading_samm_table(root_db_path, dB)
        listOfIgnoredSamples = []
        IgnoredSamples_index = np.empty([0])

        r = w = spatial_size
        subjects = 29
        n_exp = 8
        VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        timesteps_TIM = 10
        channel = 3
        classes = 8

    elif dB == 'SAMM_CASME_Optical':
        # Merge the SAMM objective table with the CASME one.
        _, samm_objective = loading_samm_table(root_db_path, dB)
        casme_objective = loading_casme_objective_table(root_db_path, dB)
        table = np.concatenate((samm_objective, casme_objective), axis=1)

        listOfIgnoredSamples = []
        IgnoredSamples_index = np.empty([0])
        sub_items = np.empty([0])
        list_samples = filter_objective_samples(table)

        r = w = spatial_size
        subjects = 47
        n_exp = 5
        # list_samples: samples with a wanted objective class.
        VidPerSubject, list_samples = get_subfolders_num_crossdb(
            inputDir, IgnoredSamples_index, sub_items, table, list_samples)

        timesteps_TIM = 9
        channel = 3
        classes = 5

        label_file = workplace + "Classification/SAMM_CASME_Optical_label.txt"
        if os.path.isfile(label_file):
            os.remove(label_file)

    else:
        # Without this, an unknown name surfaced much later as an unbound
        # local variable.
        raise ValueError("Unsupported dB: " + str(dB))
    ######################################################

    ############ Reading Images and Labels ################
    SubperdB = Read_Input_Images_SAMM_CASME(inputDir, list_samples,
                                            listOfIgnoredSamples, dB,
                                            resizedFlag, table, workplace,
                                            spatial_size, channel)
    print("Loaded Images into the tray...")
    labelperSub = label_matching(workplace, dB, subjects, VidPerSubject)
    print("Loaded Labels into the tray...")

    if channel_flag == 1:
        SubperdB_strain = Read_Input_Images_SAMM_CASME(
            inputDir, list_samples, listOfIgnoredSamples, 'SAMM_CASME_Strain',
            resizedFlag, table, workplace, spatial_size, 1)
    elif channel_flag == 2:
        SubperdB_strain = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                            'CASME2_Strain_TIM10', resizedFlag,
                                            table, workplace, spatial_size, 1)
        SubperdB_gray = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                          'CASME2_TIM', resizedFlag, table,
                                          workplace, spatial_size, 3)
    #######################################################

    ########### Model Configurations #######################
    adam = optimizers.Adam(lr=0.00001, decay=0.000001)

    # Every accepted flag runs the spatial stage, so the temporal module is
    # always fed with features from the spatial network.
    data_dim = 4096

    # Total confusion matrix used to compute the cumulative F1 score.
    tot_mat = np.zeros((n_exp, n_exp))

    history = LossHistory()

    # Created with makedirs so a missing parent 'Weights/' directory does
    # not abort the run.
    if not os.path.isdir(weights_dir):
        os.makedirs(weights_dir)
    ########################################################

    for sub in range(subjects):
        ############### Reinitialization & weights reset of models ###############
        weight_suffix = '_' + str(train_id) + '_' + str(dB) + '_' + str(
            sub) + '.h5'
        spatial_weights_name = weights_dir + '/vgg_spatial' + weight_suffix
        temporal_weights_name = weights_dir + '/temporal_ID' + weight_suffix

        temporal_model = temporal_module(data_dim=data_dim,
                                         timesteps_TIM=timesteps_TIM,
                                         weights_path=temporal_weights_name)
        temporal_model.compile(loss='categorical_crossentropy',
                               optimizer=adam,
                               metrics=[metrics.categorical_accuracy])

        if channel_flag == 1 or channel_flag == 2:
            vgg_model = VGG_16_4_channels(spatial_size=spatial_size,
                                          weights_path=spatial_weights_name)
        else:
            # NOTE(review): this branch loads the generic VGG-Face weights
            # rather than spatial_weights_name -- confirm that is intended
            # for an evaluation run.
            vgg_model = VGG_16(spatial_size=spatial_size,
                               weights_path='VGG_Face_Deep_16.h5')
        vgg_model.compile(loss='categorical_crossentropy',
                          optimizer=adam,
                          metrics=[metrics.categorical_accuracy])
        ##########################################################################

        Train_X, Train_Y, Test_X, Test_Y, Test_Y_gt = data_loader_with_LOSO(
            sub, SubperdB, labelperSub, subjects)

        # Break every sequence into its individual frames.
        Train_X_spatial = Train_X.reshape(Train_X.shape[0] * timesteps_TIM, r,
                                          w, channel)
        Test_X_spatial = Test_X.reshape(Test_X.shape[0] * timesteps_TIM, r, w,
                                        channel)

        # Defaults to the dataset's own channel count; the multi-channel
        # modes below override it.
        total_channel = channel

        if channel_flag == 1:
            Train_X_Strain, _, Test_X_Strain, _, _ = data_loader_with_LOSO(
                sub, SubperdB_strain, labelperSub, subjects)
            Train_X_Strain = Train_X_Strain.reshape(
                Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Test_X_Strain = Test_X_Strain.reshape(
                Test_X.shape[0] * timesteps_TIM, r, w, 1)

            Train_X_spatial = np.concatenate((Train_X_spatial, Train_X_Strain),
                                             axis=3)
            Test_X_spatial = np.concatenate((Test_X_spatial, Test_X_Strain),
                                            axis=3)
            total_channel = 4

        elif channel_flag == 2:
            # NOTE(review): this is the only loader call that passes
            # `classes`; confirm data_loader_with_LOSO's signature.
            Train_X_Strain, _, Test_X_Strain, _, _ = data_loader_with_LOSO(
                sub, SubperdB_strain, labelperSub, subjects, classes)
            Train_X_gray, _, Test_X_gray, _, _ = data_loader_with_LOSO(
                sub, SubperdB_gray, labelperSub, subjects)
            Train_X_Strain = Train_X_Strain.reshape(
                Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Test_X_Strain = Test_X_Strain.reshape(
                Test_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Train_X_gray = Train_X_gray.reshape(
                Train_X_gray.shape[0] * timesteps_TIM, r, w, 3)
            Test_X_gray = Test_X_gray.reshape(
                Test_X_gray.shape[0] * timesteps_TIM, r, w, 3)

            Train_X_spatial = np.concatenate(
                (Train_X_spatial, Train_X_Strain, Train_X_gray), axis=3)
            Test_X_spatial = np.concatenate(
                (Test_X_spatial, Test_X_Strain, Test_X_gray), axis=3)
            total_channel = 7

        if channel == 1:
            # Duplicate the channel of single-channel input images.
            Train_X_spatial = duplicate_channel(Train_X_spatial)
            Test_X_spatial = duplicate_channel(Test_X_spatial)

        # Repeat each sequence label once per frame.
        Train_Y_spatial = np.repeat(Train_Y, timesteps_TIM, axis=0)
        Test_Y_spatial = np.repeat(Test_Y, timesteps_TIM, axis=0)

        print("Train_X_shape: " + str(np.shape(Train_X_spatial)))
        print("Train_Y_shape: " + str(np.shape(Train_Y_spatial)))
        print("Test_X_shape: " + str(np.shape(Test_X_spatial)))
        print("Test_Y_shape: " + str(np.shape(Test_Y_spatial)))

        ##################### Testing #########################
        # NOTE(review): these are reshapes, not axis transposes, so the
        # channel axis is not actually moved to the front.  Kept as-is so
        # the input layout matches what the saved weights were trained on.
        X = Train_X_spatial.reshape(Train_X_spatial.shape[0], total_channel, r,
                                    w)
        test_X = Test_X_spatial.reshape(Test_X_spatial.shape[0], total_channel,
                                        r, w)

        print(X.shape)

        if finetuning_flag == 1:
            for layer in vgg_model.layers[:33]:
                layer.trainable = False

        # NOTE(review): no fit runs in this function, so these lines record
        # whatever LossHistory holds right after construction.
        with open(result_dir + 'loss_' + str(train_id) + '.txt',
                  'a') as file_loss:
            file_loss.write(str(history.losses) + "\n")
        with open(result_dir + 'accuracy_' + str(train_id) + '.txt',
                  'a') as file_acc:
            file_acc.write(str(history.accuracy) + "\n")

        model = Model(inputs=vgg_model.input,
                      outputs=vgg_model.layers[35].output)
        plot_model(model,
                   to_file="spatial_module_FULL_TRAINING.png",
                   show_shapes=True)

        output = model.predict(test_X, batch_size=batch_size)
        features = output.reshape(Test_X.shape[0], timesteps_TIM,
                                  output.shape[1])
        predict = temporal_model.predict_classes(features,
                                                 batch_size=batch_size)
        ##############################################################

        #################### Confusion Matrix Construction #############
        print(predict)
        print(Test_Y_gt)

        ct = confusion_matrix(Test_Y_gt, predict)
        # Classes that occur in this fold, in the row/column order of ct.
        order = np.unique(np.concatenate((predict, Test_Y_gt)))

        # Scatter this fold's matrix into the full n_exp x n_exp layout so
        # folds with different class subsets can be summed.
        mat = np.zeros((n_exp, n_exp))
        class_idx = order.astype(int)
        mat[np.ix_(class_idx, class_idx)] = ct

        tot_mat = mat + tot_mat
        ################################################################

        #################### cumulative f1 ######################
        f1, _, _ = fpr(tot_mat, n_exp)

        with open(result_dir + 'f1_' + str(train_id) + '.txt', 'a') as file_f1:
            file_f1.write(str(f1) + "\n")
        #########################################################

        ################# write each CT of each CV into .txt file #####################
        record_scores(workplace, dB, ct, sub, order, tot_mat, n_exp, subjects)
        ###############################################################################
示例#3
0
def train_vgg_lstm(batch_size, spatial_epochs, temporal_epochs, train_id,
                   list_dB, spatial_size, flag, objective_flag, tensorboard):
    """Train and evaluate a VGG (spatial) + temporal-module pipeline per subject.

    For every subject returned by ``read_subjects_todo`` the function builds
    fresh models, trains the spatial VGG network(s) on single frames, encodes
    the frames with the ``'dense_prvi'`` layer, trains the temporal module on
    the encoded sequences, predicts the test split produced by
    ``restructure_data`` and appends predictions, F1 and the confusion matrix
    to text files under ``<db_home>Classification/Result/``.

    Args:
        batch_size: Batch size used for every ``fit``/``predict`` call.
        spatial_epochs: Number of epochs for the VGG network(s).
        temporal_epochs: Number of epochs for the temporal module.
        train_id: Identifier used in the weight directory, TensorBoard log
            directories and result file names.
        list_dB: Database names. Element 0 is the main database; element 1
            (strain) and element 2 (gray) are only read by multi-input flags.
        spatial_size: Forwarded to ``load_db``, the image readers and the VGG
            constructors.
        flag: Training mode. Supported values:
            ``'st'`` (single stream, first 33 layers frozen),
            ``'scratch'`` (single stream, nothing frozen),
            ``'st4se'``/``'st7se'`` (strain / strain+gray stacked as extra
            input channels) and ``'st4te'``/``'st7te'`` (strain / strain+gray
            as separate VGG streams whose features are concatenated).
            Each of the last four also accepts a ``'_cde'`` suffix.
        objective_flag: Forwarded to ``load_db`` and ``Read_Input_Images``.
        tensorboard: ``1`` enables TensorBoard callbacks.

    Raises:
        NotImplementedError: For ``'s'``, ``'t'`` and ``'nofine'``; this
            function only implements joint spatial + temporal training.
        ValueError: For any other unknown ``flag``.
    """
    # Local import: only needed here, and the file-level import block is not
    # guaranteed to provide it.
    import shutil

    ############## Flags ####################
    # flag -> (finetuning_flag, channel_flag)
    #   channel_flag 1/2: strain (/gray) stacked onto the input channels
    #   channel_flag 3/4: strain (/gray) handled by separate VGG streams
    supported_flags = {
        'st': (1, 0),
        'scratch': (0, 0),
        'st4se': (0, 1), 'st4se_cde': (0, 1),
        'st7se': (0, 2), 'st7se_cde': (0, 2),
        'st4te': (0, 3), 'st4te_cde': (0, 3),
        'st7te': (0, 4), 'st7te_cde': (0, 4),
    }
    # Fail fast: these modes have no training/prediction path below, so they
    # would otherwise only crash after all images had been loaded.
    if flag in ('s', 't', 'nofine'):
        raise NotImplementedError(
            "flag %r is not supported by train_vgg_lstm; only joint "
            "spatial + temporal training is implemented" % (flag,))
    if flag not in supported_flags:
        raise ValueError("unknown flag %r; expected one of %s" %
                         (flag, sorted(supported_flags)))
    finetuning_flag, channel_flag = supported_flags[flag]
    has_aux = channel_flag != 0          # a strain database is used
    has_gray = channel_flag in (2, 4)    # a gray database is used as well
    multi_stream = channel_flag in (3, 4)
    # Stacked-channel modes read 1-channel auxiliary images, separate-stream
    # modes read 3-channel ones.
    aux_channels = 1 if channel_flag in (1, 2) else 3
    tensorboard_flag = tensorboard
    resizedFlag = 1
    #########################################

    ############## Path Preparation ######################
    root_db_path = "/media/ostalo/MihaGarafolj/ME_data/"
    tensorboard_path = root_db_path + "tensorboard/"
    weights_prefix = root_db_path + 'Weights/' + str(train_id) + '/'
    if not os.path.isdir(root_db_path + 'Weights/' + str(train_id)):
        os.mkdir(root_db_path + 'Weights/' + str(train_id))
    ######################################################

    ############## Variables ###################
    dB = list_dB[0]
    # data_dim returned by load_db is replaced further down by the width of
    # the spatial feature vector.
    (r, w, subjects, samples, n_exp, VidPerSubject, vidList, timesteps_TIM,
     data_dim, channel, table, listOfIgnoredSamples, db_home, db_images,
     cross_db_flag) = load_db(root_db_path, list_dB, spatial_size,
                              objective_flag)

    weights_suffix = '_' + str(train_id) + '_' + str(dB) + '_'
    spatial_weights_name_strain = (weights_prefix + 'vgg_spatial_strain' +
                                   weights_suffix)
    spatial_weights_name_gray = (weights_prefix + 'vgg_spatial_gray' +
                                 weights_suffix)
    temporal_weights_name = weights_prefix + 'temporal_ID' + weights_suffix

    # total confusion matrix to be used in the computation of f1 score
    tot_mat = np.zeros((n_exp, n_exp))

    history = LossHistory()
    stopping = EarlyStopping(monitor='loss',
                             min_delta=0,
                             mode='min',
                             patience=3)
    ############################################

    def read_images(image_dir, db_name, n_channels):
        """Read one database with the reader matching cross_db_flag."""
        if cross_db_flag == 1:
            # The cross-database reader receives the list returned by load_db
            # twice: once as list_samples, once as listOfIgnoredSamples.
            return Read_Input_Images_SAMM_CASME(
                image_dir, listOfIgnoredSamples, listOfIgnoredSamples,
                db_name, resizedFlag, table, db_home, spatial_size,
                n_channels)
        return Read_Input_Images(image_dir, listOfIgnoredSamples, db_name,
                                 resizedFlag, table, db_home, spatial_size,
                                 n_channels, objective_flag)

    def compile_spatial(model, optimizer):
        """Optionally freeze the first 33 layers, then compile the model."""
        if finetuning_flag == 1:
            for layer in model.layers[:33]:
                layer.trainable = False
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=[metrics.categorical_accuracy])
        return model

    def pretrained_vgg():
        """Build a 3-channel VGG initialised from the VGG-Face weights."""
        return VGG_16(spatial_size=spatial_size,
                      classes=n_exp,
                      channels=3,
                      weights_path='VGG_Face_Deep_16.h5')

    def encode(streams, n_videos):
        """Encode frames with every (encoder, frames) stream and fuse them.

        Per-frame features of all streams are concatenated along the feature
        axis *before* being reshaped into (videos, timesteps, features), so
        the feature width matches the temporal module's data_dim.
        """
        per_frame = np.concatenate(
            [encoder.predict(frames, batch_size=batch_size)
             for encoder, frames in streams],
            axis=1)
        return per_frame.reshape(int(n_videos), timesteps_TIM,
                                 per_frame.shape[1])

    ############ Reading Images and Labels ################
    SubperdB = read_images(db_images, dB, channel)

    labelperSub = label_matching(db_home, dB, subjects, VidPerSubject)
    print("Loaded Images into the tray.")
    print("Loaded Labels into the tray.")

    gc.collect()

    if has_aux:
        aux_db1 = list_dB[1]
        aux_db2 = list_dB[2] if has_gray else None
        # The auxiliary data always goes into its own variable; the main
        # SubperdB must stay intact for restructure_data below.
        SubperdB_strain = read_images(
            root_db_path + aux_db1 + "/" + aux_db1 + "/", aux_db1,
            aux_channels)
        if has_gray:
            SubperdB_gray = read_images(
                root_db_path + aux_db2 + "/" + aux_db2 + "/", aux_db2,
                aux_channels)
    #######################################################

    ########### Model Configurations #######################
    # Width of the fused per-frame feature vector fed to the temporal module
    # (presumably 4096 per VGG stream -- TODO confirm against the VGG
    # definition and record_weights).
    if channel_flag == 3:
        data_dim = 8192
    elif channel_flag == 4:
        data_dim = 12288
    else:
        data_dim = 4096
    ########################################################

    print("Beginning training process.")
    ########### Training Process ############
    subjects_todo = read_subjects_todo(db_home, dB, train_id, subjects)

    for sub in subjects_todo:

        adam = optimizers.Adam(lr=0.00001, decay=0.000001)

        print("**** starting subject " + str(sub) + " ****")

        ############### Reinitialization & weights reset of models ########
        temporal_model = temporal_module(data_dim=data_dim,
                                         timesteps_TIM=timesteps_TIM,
                                         lstm1_size=3000,
                                         classes=n_exp)
        temporal_model.compile(loss='categorical_crossentropy',
                               optimizer=adam,
                               metrics=[metrics.categorical_accuracy])

        if channel_flag in (1, 2):
            # 3 main channels + strain (+ gray) stacked into one input
            vgg_model = compile_spatial(
                VGG_16_4_channels(classes=n_exp,
                                  channels=4 if channel_flag == 1 else 5,
                                  spatial_size=spatial_size), adam)
        elif multi_stream:
            vgg_model = compile_spatial(pretrained_vgg(), adam)
            vgg_model_strain = compile_spatial(pretrained_vgg(), adam)
            if channel_flag == 4:
                vgg_model_gray = compile_spatial(pretrained_vgg(), adam)
        else:
            vgg_model = compile_spatial(
                VGG_16(spatial_size=spatial_size,
                       classes=n_exp,
                       channels=channel,
                       channel_first=False,
                       weights_path='VGG_Face_Deep_16.h5'), adam)
        ####################################################################

        ############ for tensorboard ###############
        if tensorboard_flag == 1:
            cat_path = tensorboard_path + str(train_id) + str(sub) + "/"
            cat_path2 = tensorboard_path + str(train_id) + str(
                sub) + "spatial/"
            for log_dir in (cat_path, cat_path2):
                # Log directories from an earlier run contain event files,
                # so they have to be removed recursively.
                if os.path.exists(log_dir):
                    shutil.rmtree(log_dir)
                os.makedirs(log_dir)
            tbCallBack = keras.callbacks.TensorBoard(log_dir=cat_path,
                                                     write_graph=True)
            tbCallBack2 = keras.callbacks.TensorBoard(log_dir=cat_path2,
                                                      write_graph=True)
        #############################################

        # Train_X/Test_X are per-video sequences, X/test_X the corresponding
        # per-frame arrays (presumably -- restructure_data is defined
        # elsewhere).
        Train_X, Train_Y, Test_X, _, Test_Y_gt, X, y, test_X, _ = restructure_data(
            sub, SubperdB, labelperSub, subjects, n_exp, r, w, timesteps_TIM,
            channel)

        if has_aux:
            _, _, _, _, _, Train_X_Strain, _, Test_X_Strain, _ = restructure_data(
                sub, SubperdB_strain, labelperSub, subjects, n_exp, r, w,
                timesteps_TIM, aux_channels)
        if has_gray:
            _, _, _, _, _, Train_X_Gray, _, Test_X_Gray, _ = restructure_data(
                sub, SubperdB_gray, labelperSub, subjects, n_exp, r, w,
                timesteps_TIM, aux_channels)

        # Stacked-channel modes: append the auxiliary images along axis 1.
        if channel_flag == 1:
            X = np.concatenate((X, Train_X_Strain), axis=1)
            test_X = np.concatenate((test_X, Test_X_Strain), axis=1)
        elif channel_flag == 2:
            X = np.concatenate((X, Train_X_Strain, Train_X_Gray), axis=1)
            test_X = np.concatenate((test_X, Test_X_Strain, Test_X_Gray),
                                    axis=1)

        print("Beginning training & testing.")
        ##################### Training & Testing #########################

        print("Beginning spatial training.")
        # TensorBoard only adds a callback; it must not change which models
        # get trained.
        spatial_callbacks = [history, stopping]
        if tensorboard_flag == 1:
            spatial_callbacks.append(tbCallBack2)
        vgg_model.fit(X,
                      y,
                      batch_size=batch_size,
                      epochs=spatial_epochs,
                      shuffle=True,
                      callbacks=spatial_callbacks)

        if multi_stream:
            vgg_model_strain.fit(Train_X_Strain,
                                 y,
                                 batch_size=batch_size,
                                 epochs=spatial_epochs,
                                 shuffle=True,
                                 callbacks=[stopping])
            # record_weights is defined elsewhere; its return value is used
            # as the feature extractor of this stream.
            model_strain = record_weights(vgg_model_strain,
                                          spatial_weights_name_strain, sub,
                                          flag)
            if channel_flag == 4:
                vgg_model_gray.fit(Train_X_Gray,
                                   y,
                                   batch_size=batch_size,
                                   epochs=spatial_epochs,
                                   shuffle=True,
                                   callbacks=[stopping])
                model_gray = record_weights(vgg_model_gray,
                                            spatial_weights_name_gray, sub,
                                            flag)

        print(".record f1 and loss")
        record_loss_accuracy(db_home, train_id, dB, history)

        print(".spatial encoding")
        # The main stream is encoded with the 'dense_prvi' layer; its weights
        # are not written to disk.
        model_int = Model(inputs=vgg_model.input,
                          outputs=vgg_model.get_layer('dense_prvi').output)

        train_streams = [(model_int, X)]
        test_streams = [(model_int, test_X)]
        if multi_stream:
            train_streams.append((model_strain, Train_X_Strain))
            test_streams.append((model_strain, Test_X_Strain))
        if channel_flag == 4:
            train_streams.append((model_gray, Train_X_Gray))
            test_streams.append((model_gray, Test_X_Gray))

        features = encode(train_streams, Train_X.shape[0])

        print("Beginning temporal training.")
        temporal_model.fit(
            features,
            Train_Y,
            batch_size=batch_size,
            epochs=temporal_epochs,
            callbacks=[tbCallBack] if tensorboard_flag == 1 else None)

        print(".save temportal weights")
        temporal_model = record_weights(temporal_model,
                                        temporal_weights_name, sub,
                                        't')  # let the flag be t

        print("Beginning testing.")
        print(".predicting with spatial model")
        features = encode(test_streams, Test_X.shape[0])

        print(".outputing features")
        print(".predicting with temporal model")
        predict_values = temporal_model.predict(features,
                                                batch_size=batch_size)
        predict = np.argmax(predict_values, axis=1)
        ##############################################################

        #################### Confusion Matrix Construction #############
        gt_labels = Test_Y_gt.astype(int)
        print(predict)
        print(gt_labels)

        print(".writing predicts to file")
        result_dir = db_home + 'Classification/' + 'Result/'
        with open(result_dir + '/predicts_' + str(train_id) + '.txt',
                  'a') as out:
            for i, vid in enumerate(vidList[sub]):
                out.write("sub_" + str(sub) + "," + str(vid) + "," +
                          str(predict[i]) + "," + str(gt_labels[i]) + "\n")

        with open(result_dir + '/predictedvalues_' + str(train_id) + '.txt',
                  'a') as out:
            for i, vid in enumerate(vidList[sub]):
                out.write("sub_" + str(sub) + "," + str(vid) + "," +
                          ','.join(str(e) for e in predict_values[i]) + "," +
                          str(gt_labels[i]) + "\n")

        ct = confusion_matrix(Test_Y_gt, predict)
        # ct only covers the labels present in this fold; `order` maps its
        # rows/columns back to the absolute class indices.
        order = np.unique(np.concatenate((predict, Test_Y_gt)))
        class_idx = order.astype(int)

        mat = np.zeros((n_exp, n_exp))
        mat[np.ix_(class_idx, class_idx)] = ct

        tot_mat = mat + tot_mat
        ################################################################

        #################### cumulative f1 plotting ######################
        [f1, precision, recall] = fpr(tot_mat, n_exp)

        with open(result_dir + '/f1_' + str(train_id) + '.txt', 'a') as out:
            out.write(str(f1) + "\n")
        ##################################################################

        ################# write each CT of each CV into .txt file ########
        record_scores(db_home, dB, ct, sub, order, tot_mat, n_exp, subjects)
        war = weighted_average_recall(tot_mat, n_exp, samples)
        uar = unweighted_average_recall(tot_mat, n_exp)
        print("war: " + str(war))
        print("uar: " + str(uar))
        ##################################################################

        ################## free memory ####################
        # Drop every reference to this subject's models and arrays before
        # the backend session is recycled.
        del train_streams, test_streams, features
        del vgg_model, temporal_model, model_int
        del Train_X, Test_X, X, y, test_X

        if has_aux:
            del Train_X_Strain, Test_X_Strain
        if has_gray:
            del Train_X_Gray, Test_X_Gray
        if multi_stream:
            del vgg_model_strain, model_strain
        if channel_flag == 4:
            del vgg_model_gray, model_gray

        K.get_session().close()
        cfg = K.tf.ConfigProto()
        cfg.gpu_options.allow_growth = True
        K.set_session(K.tf.Session(config=cfg))

        gc.collect()
示例#4
0



	############### Reinitialization & weights reset of models ########################

	vgg_model_cam = VGG_16(spatial_size=spatial_size, classes=classes, weights_path='VGG_Face_Deep_16.h5')

	temporal_model = temporal_module(data_dim=data_dim, classes=classes, timesteps_TIM=timesteps_TIM)
	temporal_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=[metrics.categorical_accuracy])

	conv_ae = convolutional_autoencoder(spatial_size = spatial_size, classes = classes)
	conv_ae.compile(loss='binary_crossentropy', optimizer=adam)

	if channel_flag == 1 or channel_flag == 2:
		vgg_model = VGG_16_4_channels(classes=classes, spatial_size = spatial_size)
		vgg_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=[metrics.categorical_accuracy])
	else:
		vgg_model = VGG_16(spatial_size = spatial_size, classes=classes, weights_path='VGG_Face_Deep_16.h5')
		vgg_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=[metrics.categorical_accuracy])

	svm_classifier = SVC(kernel='linear', C=1)
		####################################################################################
		
		
		############ for tensorboard ###############
	if tensorboard_flag == 1:
		cat_path = tensorboard_path + str(sub) + "/"
		os.mkdir(cat_path)
		tbCallBack = keras.callbacks.TensorBoard(log_dir=cat_path, write_graph=True)
def train(batch_size, spatial_epochs, temporal_epochs, train_id, list_dB,
          spatial_size, flag, objective_flag, tensorboard):
    """Train and evaluate the spatial (VGG) + temporal pipeline per subject.

    For every subject index the function rebuilds the models, obtains the
    train/test split from ``restructure_data`` (presumably a
    leave-one-subject-out split -- confirm against that helper), trains the
    spatial network(s), feeds their predictions to the temporal model,
    predicts on the held-out data and accumulates a confusion matrix from
    which F1 / WAR / UAR are computed and recorded.

    Args:
        batch_size: Batch size forwarded to every ``fit`` / ``predict`` call.
        spatial_epochs: Number of epochs for the VGG (spatial) networks.
        temporal_epochs: Number of epochs for the temporal model.
        train_id: Identifier used in the weight directory, the weight file
            names and the F1 result file name.
        list_dB: Database names. ``list_dB[0]`` is the main database,
            ``list_dB[1]`` the auxiliary "strain" database (channel flags
            1-4) and ``list_dB[2]`` the auxiliary "gray" database (channel
            flags 2 and 4).
        spatial_size: Side length passed to the loaders and model builders.
        flag: Training mode string ('st', 's', 't', 'nofine', 'scratch',
            'st4se', 'st7se', 'st4te', 'st7te' or their '_cde' variants).
        objective_flag: Forwarded to ``load_db`` and ``Read_Input_Images``.
        tensorboard: ``1`` enables the TensorBoard callbacks.

    Returns:
        None. Results are written to disk by the recording helpers.
    """
    ############## Path Preparation ######################
    root_db_path = "/media/ice/OS/Datasets/"
    tensorboard_path = "/home/ice/Documents/Micro-Expression/tensorboard/"
    weights_dir = root_db_path + 'Weights/' + str(train_id)
    if not os.path.isdir(weights_dir):
        os.mkdir(weights_dir)
    ######################################################

    ############## Variables ###################
    dB = list_dB[0]
    (r, w, subjects, samples, n_exp, VidPerSubject, timesteps_TIM, data_dim,
     channel, table, listOfIgnoredSamples, db_home, db_images,
     cross_db_flag) = load_db(root_db_path, list_dB, spatial_size,
                              objective_flag)

    # avoid confusion: in the cross-database setting load_db hands back the
    # sample list in the slot otherwise used for the ignored samples
    if cross_db_flag == 1:
        list_samples = listOfIgnoredSamples

    # total confusion matrix to be used in the computation of f1 score
    tot_mat = np.zeros((n_exp, n_exp))

    history = LossHistory()
    stopping = EarlyStopping(monitor='loss', min_delta=0, mode='min')
    ############################################

    ############## Flags ####################
    tensorboard_flag = tensorboard
    resizedFlag = 1
    train_spatial_flag = 0
    train_temporal_flag = 0
    svm_flag = 0  # set for 'nofine'; not consumed anywhere in this function
    finetuning_flag = 0
    channel_flag = 0

    if flag == 'st':
        train_spatial_flag = 1
        train_temporal_flag = 1
        finetuning_flag = 1
    elif flag == 's':
        train_spatial_flag = 1
        finetuning_flag = 1
    elif flag == 't':
        train_temporal_flag = 1
    elif flag == 'nofine':
        svm_flag = 1
    elif flag == 'scratch':
        train_spatial_flag = 1
        train_temporal_flag = 1
    elif flag in ('st4se', 'st4se_cde'):
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 1
    elif flag in ('st7se', 'st7se_cde'):
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 2
    elif flag in ('st4te', 'st4te_cde'):
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 3
    elif flag in ('st7te', 'st7te_cde'):
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 4

    uses_strain = channel_flag in (1, 2, 3, 4)
    uses_gray = channel_flag in (2, 4)
    separate_streams = channel_flag in (3, 4)
    # auxiliary inputs are read with 1 channel when they are stacked onto the
    # main input (flags 1, 2) and with 3 channels when they feed their own
    # VGG stream (flags 3, 4)
    aux_channels = 1 if channel_flag in (1, 2) else 3
    #########################################

    ############ Reading Images and Labels ################
    def read_images(image_dir, db_name, n_channels):
        # Pick the loader that matches the single / cross database setting.
        if cross_db_flag == 1:
            return Read_Input_Images_SAMM_CASME(
                image_dir, list_samples, listOfIgnoredSamples, db_name,
                resizedFlag, table, db_home, spatial_size, n_channels)
        return Read_Input_Images(image_dir, listOfIgnoredSamples, db_name,
                                 resizedFlag, table, db_home, spatial_size,
                                 n_channels, objective_flag)

    SubperdB = read_images(db_images, dB, channel)

    labelperSub = label_matching(db_home, dB, subjects, VidPerSubject)
    print("Loaded Images into the tray...")
    print("Loaded Labels into the tray...")

    if uses_strain:
        aux_db1 = list_dB[1]
        if uses_gray:
            aux_db2 = list_dB[2]
        # The original cross-db branch for flags 1 and 3 stored this result
        # in SubperdB, overwriting the main data and leaving SubperdB_strain
        # undefined; it always belongs in SubperdB_strain.
        SubperdB_strain = read_images(
            root_db_path + aux_db1 + "/" + aux_db1 + "/", aux_db1,
            aux_channels)
        if uses_gray:
            SubperdB_gray = read_images(
                root_db_path + aux_db2 + "/" + aux_db2 + "/", aux_db2,
                aux_channels)
    #######################################################

    ########### Model Configurations #######################
    adam = optimizers.Adam(lr=0.00001, decay=0.000001)

    def compile_classifier(net):
        # All classifiers in this function share loss, optimizer and metric.
        net.compile(loss='categorical_crossentropy',
                    optimizer=adam,
                    metrics=[metrics.categorical_accuracy])
        return net

    def build_pretrained_vgg():
        return compile_classifier(
            VGG_16(spatial_size=spatial_size,
                   classes=n_exp,
                   channels=3,
                   weights_path='VGG_Face_Deep_16.h5'))

    # Different Conditions for Temporal Learning ONLY
    temporal_only = train_spatial_flag == 0 and train_temporal_flag == 1
    if temporal_only and dB != 'CASME2_Optical':
        data_dim = spatial_size * spatial_size
    elif temporal_only and dB == 'CASME2_Optical':
        data_dim = spatial_size * spatial_size * 3
    elif channel_flag == 3:
        data_dim = 8192
    elif channel_flag == 4:
        data_dim = 12288
    else:
        data_dim = 4096

    # weight file prefixes do not depend on the subject, so build them once
    name_suffix = str(train_id) + '_' + str(dB) + '_'
    spatial_weights_name = weights_dir + '/vgg_spatial_' + name_suffix
    spatial_weights_name_strain = (weights_dir + '/vgg_spatial_strain_' +
                                   name_suffix)
    spatial_weights_name_gray = (weights_dir + '/vgg_spatial_gray_' +
                                 name_suffix)
    temporal_weights_name = weights_dir + '/temporal_ID_' + name_suffix
    ########################################################

    ########### Training Process ############
    for sub in range(subjects):

        ############### Reinitialization & weights reset of models ########################
        temporal_model = compile_classifier(
            temporal_module(data_dim=data_dim,
                            timesteps_TIM=timesteps_TIM,
                            classes=n_exp))

        # Built and compiled but never trained or queried below; kept so the
        # sequence of model constructions matches the original script.
        conv_ae = convolutional_autoencoder(spatial_size=spatial_size,
                                            classes=n_exp)
        conv_ae.compile(loss='binary_crossentropy', optimizer=adam)

        if channel_flag == 1:
            vgg_model = compile_classifier(
                VGG_16_4_channels(classes=n_exp,
                                  channels=4,
                                  spatial_size=spatial_size))
        elif channel_flag == 2:
            vgg_model = compile_classifier(
                VGG_16_4_channels(classes=n_exp,
                                  channels=5,
                                  spatial_size=spatial_size))
        else:
            vgg_model = build_pretrained_vgg()
            if separate_streams:
                vgg_model_strain = build_pretrained_vgg()
                if channel_flag == 4:
                    vgg_model_gray = build_pretrained_vgg()
        ####################################################################################

        ############ for tensorboard ###############
        if tensorboard_flag == 1:
            cat_path = tensorboard_path + str(sub) + "/"
            os.mkdir(cat_path)
            tbCallBack = keras.callbacks.TensorBoard(log_dir=cat_path,
                                                     write_graph=True)

            cat_path2 = tensorboard_path + str(sub) + "spat/"
            os.mkdir(cat_path2)
            tbCallBack2 = keras.callbacks.TensorBoard(log_dir=cat_path2,
                                                      write_graph=True)
        #############################################

        (Train_X, Train_Y, Test_X, Test_Y, Test_Y_gt, X, y, test_X,
         test_y) = restructure_data(sub, SubperdB, labelperSub, subjects,
                                    n_exp, r, w, timesteps_TIM, channel)

        # Special loading for the auxiliary inputs
        if uses_strain:
            (_, _, _, _, _, Train_X_Strain, Train_Y_Strain, Test_X_Strain,
             Test_Y_Strain) = restructure_data(sub, SubperdB_strain,
                                               labelperSub, subjects, n_exp,
                                               r, w, timesteps_TIM,
                                               aux_channels)
        if uses_gray:
            (_, _, _, _, _, Train_X_Gray, Train_Y_Gray, Test_X_Gray,
             Test_Y_Gray) = restructure_data(sub, SubperdB_gray, labelperSub,
                                             subjects, n_exp, r, w,
                                             timesteps_TIM, aux_channels)

        if channel_flag == 1:
            # Concatenate Train X & Train_X_Strain along axis 1
            X = np.concatenate((X, Train_X_Strain), axis=1)
            test_X = np.concatenate((test_X, Test_X_Strain), axis=1)
        elif channel_flag == 2:
            # Concatenate Train_X & Train_X_Strain & Train_X_Gray
            # (the original referenced the misspelled *_gray names here)
            X = np.concatenate((X, Train_X_Strain, Train_X_Gray), axis=1)
            test_X = np.concatenate((test_X, Test_X_Strain, Test_X_Gray),
                                    axis=1)

        ############### check gpu resources ####################
        gpu_observer()
        ########################################################

        ##################### Training & Testing #########################
        # conv weights must be freezed for transfer learning
        # NOTE(review): the models were already compiled above; Keras
        # documents that a change to `trainable` takes effect at the next
        # compile() -- confirm whether a recompile is intended here.
        if finetuning_flag == 1:
            for layer in vgg_model.layers[:33]:
                layer.trainable = False
            if separate_streams:
                for layer in vgg_model_strain.layers[:33]:
                    layer.trainable = False
                if channel_flag == 4:
                    for layer in vgg_model_gray.layers[:33]:
                        layer.trainable = False

        # NOTE(review): `predict` and `model` are only bound inside this
        # branch, so flags that do not enable both stages ('s', 't',
        # 'nofine') fail with a NameError further down.
        if train_spatial_flag == 1 and train_temporal_flag == 1:

            # Spatial Training
            if tensorboard_flag == 1:
                spatial_callbacks = [tbCallBack2]
            else:
                spatial_callbacks = [history, stopping]
            vgg_model.fit(X,
                          y,
                          batch_size=batch_size,
                          epochs=spatial_epochs,
                          shuffle=True,
                          callbacks=spatial_callbacks)

            # The auxiliary streams are trained regardless of the tensorboard
            # setting; previously enabling tensorboard skipped them and left
            # model_strain / output_strain undefined.
            if separate_streams:
                vgg_model_strain.fit(Train_X_Strain,
                                     y,
                                     batch_size=batch_size,
                                     epochs=spatial_epochs,
                                     shuffle=True,
                                     callbacks=[stopping])
                model_strain = record_weights(vgg_model_strain,
                                              spatial_weights_name_strain, sub,
                                              flag)
                output_strain = model_strain.predict(Train_X_Strain,
                                                     batch_size=batch_size)
                if channel_flag == 4:
                    vgg_model_gray.fit(Train_X_Gray,
                                       y,
                                       batch_size=batch_size,
                                       epochs=spatial_epochs,
                                       shuffle=True,
                                       callbacks=[stopping])
                    model_gray = record_weights(vgg_model_gray,
                                                spatial_weights_name_gray, sub,
                                                flag)
                    output_gray = model_gray.predict(Train_X_Gray,
                                                     batch_size=batch_size)

            # record f1 and loss
            record_loss_accuracy(db_home, train_id, dB, history)

            # save vgg weights
            model = record_weights(vgg_model, spatial_weights_name, sub, flag)

            # Spatial Encoding
            output = model.predict(X, batch_size=batch_size)

            # concatenate features for temporal enrichment
            if channel_flag == 3:
                output = np.concatenate((output, output_strain), axis=1)
            elif channel_flag == 4:
                output = np.concatenate((output, output_strain, output_gray),
                                        axis=1)

            features = output.reshape(int(Train_X.shape[0]), timesteps_TIM,
                                      output.shape[1])

            # Temporal Training
            if tensorboard_flag == 1:
                temporal_model.fit(features,
                                   Train_Y,
                                   batch_size=batch_size,
                                   epochs=temporal_epochs,
                                   callbacks=[tbCallBack])
            else:
                temporal_model.fit(features,
                                   Train_Y,
                                   batch_size=batch_size,
                                   epochs=temporal_epochs)

            # save temporal weights
            temporal_model = record_weights(temporal_model,
                                            temporal_weights_name, sub,
                                            't')  # let the flag be t

            # Testing
            output = model.predict(test_X, batch_size=batch_size)
            if separate_streams:
                output_strain = model_strain.predict(Test_X_Strain,
                                                     batch_size=batch_size)
                if channel_flag == 4:
                    output_gray = model_gray.predict(Test_X_Gray,
                                                     batch_size=batch_size)

            # concatenate features for temporal enrichment
            if channel_flag == 3:
                output = np.concatenate((output, output_strain), axis=1)
            elif channel_flag == 4:
                output = np.concatenate((output, output_strain, output_gray),
                                        axis=1)

            features = output.reshape(Test_X.shape[0], timesteps_TIM,
                                      output.shape[1])

            predict = temporal_model.predict_classes(features,
                                                     batch_size=batch_size)
        ##############################################################

        #################### Confusion Matrix Construction #############
        print(predict)
        print(Test_Y_gt)

        ct = confusion_matrix(Test_Y_gt, predict)
        # check the order of the CT
        order = np.unique(np.concatenate((predict, Test_Y_gt)))

        # scatter this fold's CT into a full n_exp x n_exp matrix so that
        # folds with missing classes can still be summed
        mat = np.zeros((n_exp, n_exp))
        class_idx = order.astype(int)
        mat[np.ix_(class_idx, class_idx)] = ct

        tot_mat = mat + tot_mat
        ################################################################

        #################### cumulative f1 plotting ######################
        f1, _, _ = fpr(tot_mat, n_exp)

        f1_path = (db_home + 'Classification/' + 'Result/' + dB + '/f1_' +
                   str(train_id) + '.txt')
        with open(f1_path, 'a') as f1_file:
            f1_file.write(str(f1) + "\n")
        ##################################################################

        ################# write each CT of each CV into .txt file #####################
        record_scores(db_home, dB, ct, sub, order, tot_mat, n_exp, subjects)
        war = weighted_average_recall(tot_mat, n_exp, samples)
        uar = unweighted_average_recall(tot_mat, n_exp)
        print("war: " + str(war))
        print("uar: " + str(uar))
        ###############################################################################

        ################## free memory ####################
        del vgg_model
        del temporal_model
        del model
        del Train_X, Test_X, X, y

        # each name is deleted exactly once; the original listed
        # Train_Y_Strain twice, which raises on the second delete
        if uses_strain:
            del Train_X_Strain, Test_X_Strain, Train_Y_Strain, Test_Y_Strain
        if uses_gray:
            del Train_X_Gray, Test_X_Gray, Train_Y_Gray, Test_Y_Gray
        if separate_streams:
            del vgg_model_strain, model_strain
        if channel_flag == 4:
            del vgg_model_gray, model_gray

        gc.collect()
        ###################################################
示例#6
0
def train_cae_lstm(batch_size, spatial_epochs, temporal_epochs, train_id, dB,
                   spatial_size, flag, tensorboard):
    ############## Path Preparation ######################
    root_db_path = "/media/ice/OS/Datasets/"
    workplace = root_db_path + dB + "/"
    inputDir = root_db_path + dB + "/" + dB + "/"
    ######################################################

    r, w, subjects, samples, n_exp, VidPerSubject, timesteps_TIM, timesteps_TIM, data_dim, channel, table, listOfIgnoredSamples = load_db(
        root_db_path, dB, spatial_size)

    # print(VidPerSubject)

    ############## Flags ####################
    tensorboard_flag = tensorboard
    resizedFlag = 1
    train_spatial_flag = 0
    train_temporal_flag = 0
    svm_flag = 0
    finetuning_flag = 0
    cam_visualizer_flag = 0
    channel_flag = 0

    if flag == 'st':
        train_spatial_flag = 1
        train_temporal_flag = 1
        finetuning_flag = 1
    elif flag == 's':
        train_spatial_flag = 1
        finetuning_flag = 1
    elif flag == 't':
        train_temporal_flag = 1
    elif flag == 'nofine':
        svm_flag = 1
    elif flag == 'scratch':
        train_spatial_flag = 1
        train_temporal_flag = 1
    elif flag == 'st4':
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 1
    elif flag == 'st7':
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 2
    #########################################

    ############ Reading Images and Labels ################
    SubperdB = Read_Input_Images(inputDir, listOfIgnoredSamples, dB,
                                 resizedFlag, table, workplace, spatial_size,
                                 channel)
    print("Loaded Images into the tray...")
    labelperSub = label_matching(workplace, dB, subjects, VidPerSubject)
    print("Loaded Labels into the tray...")

    if channel_flag == 1:
        SubperdB_strain = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                            'CASME2_Strain_TIM10', resizedFlag,
                                            table, workplace, spatial_size, 1)
    elif channel_flag == 2:
        SubperdB_strain = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                            'CASME2_Strain_TIM10', resizedFlag,
                                            table, workplace, spatial_size, 1)
        SubperdB_gray = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                          'CASME2_TIM', resizedFlag, table,
                                          workplace, spatial_size, 3)
    #######################################################

    ########### Model Configurations #######################
    sgd = optimizers.SGD(lr=0.0001, decay=1e-7, momentum=0.9, nesterov=True)
    adam = optimizers.Adam(lr=0.00001, decay=0.000001)

    # Different Conditions for Temporal Learning ONLY
    if train_spatial_flag == 0 and train_temporal_flag == 1 and dB != 'CASME2_Optical':
        data_dim = spatial_size * spatial_size
    elif train_spatial_flag == 0 and train_temporal_flag == 1 and dB == 'CASME2_Optical':
        data_dim = spatial_size * spatial_size * 3
    else:
        data_dim = 4096

    ########################################################

    ########### Training Process ############
    # Todo:
    # 1) LOSO (done)
    # 2) call model (done)
    # 3) saving model architecture
    # 4) Saving Checkpoint (done)
    # 5) make prediction (done)
    if tensorboard_flag == 1:
        tensorboard_path = "/home/ice/Documents/Micro-Expression/tensorboard/"

    # total confusion matrix to be used in the computation of f1 score
    tot_mat = np.zeros((n_exp, n_exp))

    # model checkpoint
    spatial_weights_name = 'vgg_spatial_' + str(train_id) + '_' + str(dB) + '_'
    temporal_weights_name = 'temporal_ID_' + str(train_id) + '_' + str(
        dB) + '_'
    ae_weights_name = 'autoencoder_' + str(train_id) + '_' + str(dB) + '_'
    history = LossHistory()
    stopping = EarlyStopping(monitor='loss', min_delta=0, mode='min')

    for sub in range(subjects):
        ############### Reinitialization & weights reset of models ########################

        vgg_model_cam = VGG_16(spatial_size=spatial_size,
                               classes=classes,
                               weights_path='VGG_Face_Deep_16.h5')

        temporal_model = temporal_module(data_dim=data_dim,
                                         classes=classes,
                                         timesteps_TIM=timesteps_TIM)
        temporal_model.compile(loss='categorical_crossentropy',
                               optimizer=adam,
                               metrics=[metrics.categorical_accuracy])

        conv_ae = convolutional_autoencoder(spatial_size=spatial_size,
                                            classes=classes)
        conv_ae.compile(loss='binary_crossentropy', optimizer=adam)

        if channel_flag == 1 or channel_flag == 2:
            vgg_model = VGG_16_4_channels(classes=classes,
                                          spatial_size=spatial_size)
            vgg_model.compile(loss='categorical_crossentropy',
                              optimizer=adam,
                              metrics=[metrics.categorical_accuracy])
        else:
            vgg_model = VGG_16(spatial_size=spatial_size,
                               classes=classes,
                               weights_path='VGG_Face_Deep_16.h5')
            vgg_model.compile(loss='categorical_crossentropy',
                              optimizer=adam,
                              metrics=[metrics.categorical_accuracy])

        svm_classifier = SVC(kernel='linear', C=1)
        ####################################################################################

        ############ for tensorboard ###############
        if tensorboard_flag == 1:
            cat_path = tensorboard_path + str(sub) + "/"
            os.mkdir(cat_path)
            tbCallBack = keras.callbacks.TensorBoard(log_dir=cat_path,
                                                     write_graph=True)

            cat_path2 = tensorboard_path + str(sub) + "spat/"
            os.mkdir(cat_path2)
            tbCallBack2 = keras.callbacks.TensorBoard(log_dir=cat_path2,
                                                      write_graph=True)
        #############################################

        image_label_mapping = np.empty([0])

        Train_X, Train_Y, Test_X, Test_Y, Test_Y_gt = data_loader_with_LOSO(
            sub, SubperdB, labelperSub, subjects, classes)

        # Rearrange Training labels into a vector of images, breaking sequence
        Train_X_spatial = Train_X.reshape(Train_X.shape[0] * timesteps_TIM, r,
                                          w, channel)
        Test_X_spatial = Test_X.reshape(Test_X.shape[0] * timesteps_TIM, r, w,
                                        channel)

        # Special Loading for 4-Channel
        if channel_flag == 1:
            Train_X_Strain, _, Test_X_Strain, _, _ = data_loader_with_LOSO(
                sub, SubperdB_strain, labelperSub, subjects, classes)
            Train_X_Strain = Train_X_Strain.reshape(
                Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Test_X_Strain = Test_X_Strain.reshape(
                Test_X.shape[0] * timesteps_TIM, r, w, 1)

            # Concatenate Train X & Train_X_Strain
            Train_X_spatial = np.concatenate((Train_X_spatial, Train_X_Strain),
                                             axis=3)
            Test_X_spatial = np.concatenate((Test_X_spatial, Test_X_Strain),
                                            axis=3)

            channel = 4

        elif channel_flag == 2:
            Train_X_Strain, _, Test_X_Strain, _, _ = data_loader_with_LOSO(
                sub, SubperdB_strain, labelperSub, subjects, classes)
            Train_X_gray, _, Test_X_gray, _, _ = data_loader_with_LOSO(
                sub, SubperdB_gray, labelperSub, subjects)
            Train_X_Strain = Train_X_Strain.reshape(
                Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Test_X_Strain = Test_X_Strain.reshape(
                Test_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Train_X_gray = Train_X_gray.reshape(
                Train_X_gray.shape[0] * timesteps_TIM, r, w, 3)
            Test_X_gray = Test_X_gray.reshape(
                Test_X_gray.shape[0] * timesteps_TIM, r, w, 3)

            # Concatenate Train_X_Strain & Train_X & Train_X_gray
            Train_X_spatial = np.concatenate(
                (Train_X_spatial, Train_X_Strain, Train_X_gray), axis=3)
            Test_X_spatial = np.concatenate(
                (Test_X_spatial, Test_X_Strain, Test_X_gray), axis=3)
            channel = 7

        if channel == 1:
            # Duplicate channel of input image
            Train_X_spatial = duplicate_channel(Train_X_spatial)
            Test_X_spatial = duplicate_channel(Test_X_spatial)

        # Repeat each Y label timesteps_TIM times, so that all images have labels
        Train_Y_spatial = np.repeat(Train_Y, timesteps_TIM, axis=0)
        Test_Y_spatial = np.repeat(Test_Y, timesteps_TIM, axis=0)

        # print ("Train_X_shape: " + str(np.shape(Train_X_spatial)))
        # print ("Train_Y_shape: " + str(np.shape(Train_Y_spatial)))
        # print ("Test_X_shape: " + str(np.shape(Test_X_spatial)))
        # print ("Test_Y_shape: " + str(np.shape(Test_Y_spatial)))
        # print(Train_X_spatial)
        ##################### Training & Testing #########################

        # print(Train_X_spatial.shape)

        X = Train_X_spatial.reshape(Train_X_spatial.shape[0], channel, r, w)
        y = Train_Y_spatial.reshape(Train_Y_spatial.shape[0], classes)
        normalized_X = X.astype('float32') / 255.

        test_X = Test_X_spatial.reshape(Test_X_spatial.shape[0], channel, r, w)
        test_y = Test_Y_spatial.reshape(Test_Y_spatial.shape[0], classes)
        normalized_test_X = test_X.astype('float32') / 255.

        print(X.shape)

        ###### conv weights must be frozen for transfer learning ######
        if finetuning_flag == 1:
            for layer in vgg_model.layers[:33]:
                layer.trainable = False
            for layer in vgg_model_cam.layers[:31]:
                layer.trainable = False

        if train_spatial_flag == 1 and train_temporal_flag == 1:
            # Autoencoder first training
            conv_ae.fit(normalized_X,
                        normalized_X,
                        batch_size=batch_size,
                        epochs=spatial_epochs,
                        shuffle=True)

            # remove decoder
            conv_ae.pop()
            conv_ae.pop()
            conv_ae.pop()
            conv_ae.pop()
            conv_ae.pop()
            conv_ae.pop()
            conv_ae.pop()

            # append dense layers
            conv_ae.add(Flatten())
            conv_ae.add(Dense(4096, activation='relu'))
            conv_ae.add(Dense(4096, activation='relu'))
            conv_ae.add(Dense(n_exp, activation='sigmoid'))
            model_ae = Model(inputs=conv_ae.input,
                             outputs=conv_ae.layers[9].output)
            plot_model(model_ae, to_file='autoencoders.png', show_shapes=True)

            # freeze encoder
            for layer in conv_ae.layers[:6]:
                layer.trainable = False

            # finetune dense layers
            conv_ae.compile(loss='categorical_crossentropy', optimizer=adam)
            conv_ae.fit(normalized_X,
                        y,
                        batch_size=batch_size,
                        epochs=spatial_epochs,
                        shuffle=True)

            model_ae = Model(inputs=conv_ae.input,
                             outputs=conv_ae.layers[8].output)
            plot_model(model_ae, to_file='autoencoders.png', show_shapes=True)

            # Autoencoding
            output = model_ae.predict(normalized_X, batch_size=batch_size)

            # print(output.shape)
            features = output.reshape(int(Train_X.shape[0]), timesteps_TIM,
                                      output.shape[1])

            temporal_model.fit(features,
                               Train_Y,
                               batch_size=batch_size,
                               epochs=temporal_epochs)

            temporal_model.save_weights(temporal_weights_name + str(sub) +
                                        ".h5")

            # Testing
            output = model_ae.predict(test_X, batch_size=batch_size)

            features = output.reshape(Test_X.shape[0], timesteps_TIM,
                                      output.shape[1])

            predict = temporal_model.predict_classes(features,
                                                     batch_size=batch_size)

        #################### Confusion Matrix Construction #############
        print(predict)
        print(Test_Y_gt)

        ct = confusion_matrix(Test_Y_gt, predict)
        # check the order of the CT
        order = np.unique(np.concatenate((predict, Test_Y_gt)))

        # create an array to hold the CT for each CV
        mat = np.zeros((n_exp, n_exp))
        # put the order accordingly, in order to form the overall ConfusionMat
        for m in range(len(order)):
            for n in range(len(order)):
                mat[int(order[m]), int(order[n])] = ct[m, n]

        tot_mat = mat + tot_mat
        ################################################################

        #################### cumulative f1 plotting ######################
        microAcc = np.trace(tot_mat) / np.sum(tot_mat)
        [f1, precision, recall] = fpr(tot_mat, n_exp)

        file = open(
            workplace + 'Classification/' + 'Result/' + dB + '/f1_' +
            str(train_id) + '.txt', 'a')
        file.write(str(f1) + "\n")
        file.close()
        ##################################################################

        ################# write each CT of each CV into .txt file #####################
        record_scores(workplace, dB, ct, sub, order, tot_mat, n_exp, subjects)
        ###############################################################################
示例#7
0
def train_samm(batch_size, spatial_epochs, temporal_epochs, train_id, dB,
               spatial_size, flag, tensorboard):
    ############## Path Preparation ######################
    root_db_path = "/media/ice/OS/Datasets/"
    workplace = root_db_path + dB + "/"
    inputDir = root_db_path + dB + "/" + dB + "/"
    ######################################################
    classes = 5
    if dB == 'CASME2_TIM':
        table = loading_casme_table(workplace + 'CASME2-ObjectiveClasses.xlsx')
        listOfIgnoredSamples, IgnoredSamples_index = ignore_casme_samples(
            inputDir)

        ############## Variables ###################
        r = w = spatial_size
        subjects = 2
        n_exp = 5
        # VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        listOfIgnoredSamples = []
        VidPerSubject = [2, 1]
        timesteps_TIM = 10
        data_dim = r * w
        pad_sequence = 10
        channel = 3
        ############################################

        os.remove(workplace + "Classification/CASME2_TIM_label.txt")

    elif dB == 'CASME2_Optical':
        table = loading_casme_table(workplace + 'CASME2-ObjectiveClasses.xlsx')
        listOfIgnoredSamples, IgnoredSamples_index = ignore_casme_samples(
            inputDir)

        ############## Variables ###################
        r = w = spatial_size
        subjects = 26
        n_exp = 5
        VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        timesteps_TIM = 9
        data_dim = r * w
        pad_sequence = 9
        channel = 3
        ############################################

        # os.remove(workplace + "Classification/CASME2_TIM_label.txt")

    elif dB == 'SAMM_TIM10':
        table, table_objective = loading_samm_table(root_db_path, dB)
        listOfIgnoredSamples = []
        IgnoredSamples_index = np.empty([0])

        ################# Variables #############################
        r = w = spatial_size
        subjects = 29
        n_exp = 8
        VidPerSubject = get_subfolders_num(inputDir, IgnoredSamples_index)
        timesteps_TIM = 10
        data_dim = r * w
        pad_sequence = 10
        channel = 3
        classes = 8
        #########################################################

    elif dB == 'SAMM_CASME_Strain':
        # total amount of videos 253
        table, table_objective = loading_samm_table(root_db_path, dB)
        table = table_objective
        table2 = loading_casme_objective_table(root_db_path, dB)

        # merge samm and casme tables
        table = np.concatenate((table, table2), axis=1)

        # print(table.shape)

        # listOfIgnoredSamples, IgnoredSamples_index, sub_items = ignore_casme_samples(inputDir)
        listOfIgnoredSamples = []
        IgnoredSamples_index = np.empty([0])
        sub_items = np.empty([0])
        list_samples = filter_objective_samples(table)

        r = w = spatial_size
        subjects = 47  # some subjects were removed because of objective classes and ignore samples: 47
        n_exp = 5
        # TODO:
        # 1) Further decrease the video amount, the one with objective classes >= 6
        # list samples: samples with wanted objective class
        VidPerSubject, list_samples = get_subfolders_num_crossdb(
            inputDir, IgnoredSamples_index, sub_items, table, list_samples)

        # print(VidPerSubject)
        # print(len(VidPerSubject))
        # print(sum(VidPerSubject))
        timesteps_TIM = 9
        data_dim = r * w
        channel = 3
        classes = 5
        if os.path.isfile(workplace +
                          "Classification/SAMM_CASME_Optical_label.txt"):
            os.remove(workplace +
                      "Classification/SAMM_CASME_Optical_label.txt")
        ##################### Variables ######################

        ######################################################

    ############## Flags ####################
    tensorboard_flag = tensorboard
    resizedFlag = 1
    train_spatial_flag = 0
    train_temporal_flag = 0
    svm_flag = 0
    finetuning_flag = 0
    cam_visualizer_flag = 0
    channel_flag = 0

    if flag == 'st':
        train_spatial_flag = 1
        train_temporal_flag = 1
        finetuning_flag = 1
    elif flag == 's':
        train_spatial_flag = 1
        finetuning_flag = 1
    elif flag == 't':
        train_temporal_flag = 1
    elif flag == 'nofine':
        svm_flag = 1
    elif flag == 'scratch':
        train_spatial_flag = 1
        train_temporal_flag = 1
    elif flag == 'st4':
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 1
    elif flag == 'st7':
        train_spatial_flag = 1
        train_temporal_flag = 1
        channel_flag = 2
    #########################################

    ############ Reading Images and Labels ################

    SubperdB = Read_Input_Images_SAMM_CASME(inputDir, list_samples,
                                            listOfIgnoredSamples, dB,
                                            resizedFlag, table, workplace,
                                            spatial_size, channel)
    print("Loaded Images into the tray...")
    labelperSub = label_matching(workplace, dB, subjects, VidPerSubject)
    print("Loaded Labels into the tray...")

    if channel_flag == 1:
        SubperdB_strain = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                            'CASME2_Strain_TIM10', resizedFlag,
                                            table, workplace, spatial_size, 1)
    elif channel_flag == 2:
        SubperdB_strain = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                            'CASME2_Strain_TIM10', resizedFlag,
                                            table, workplace, spatial_size, 1)
        SubperdB_gray = Read_Input_Images(inputDir, listOfIgnoredSamples,
                                          'CASME2_TIM', resizedFlag, table,
                                          workplace, spatial_size, 3)
    #######################################################

    ########### Model Configurations #######################
    sgd = optimizers.SGD(lr=0.0001, decay=1e-7, momentum=0.9, nesterov=True)
    adam = optimizers.Adam(lr=0.00001, decay=0.000001)
    adam2 = optimizers.Adam(lr=0.00075, decay=0.0001)

    # Different Conditions for Temporal Learning ONLY
    if train_spatial_flag == 0 and train_temporal_flag == 1 and dB != 'CASME2_Optical':
        data_dim = spatial_size * spatial_size
    elif train_spatial_flag == 0 and train_temporal_flag == 1 and dB == 'CASME2_Optical':
        data_dim = spatial_size * spatial_size * 3
    else:
        data_dim = 4096

    ########################################################

    ########### Training Process ############

    # total confusion matrix to be used in the computation of f1 score
    tot_mat = np.zeros((n_exp, n_exp))

    # model checkpoint
    spatial_weights_name = 'vgg_spatial_' + str(train_id) + '_casme2_'
    temporal_weights_name = 'temporal_ID_' + str(train_id) + '_casme2_'
    history = LossHistory()
    stopping = EarlyStopping(monitor='loss', min_delta=0, mode='min')

    for sub in range(subjects):
        ############### Reinitialization & weights reset of models ########################

        vgg_model_cam = VGG_16(spatial_size=spatial_size,
                               classes=classes,
                               weights_path='VGG_Face_Deep_16.h5')

        temporal_model = temporal_module(data_dim=data_dim,
                                         classes=classes,
                                         timesteps_TIM=timesteps_TIM)
        temporal_model.compile(loss='categorical_crossentropy',
                               optimizer=adam,
                               metrics=[metrics.categorical_accuracy])

        conv_ae = convolutional_autoencoder(spatial_size=spatial_size,
                                            classes=classes)
        conv_ae.compile(loss='binary_crossentropy', optimizer=adam)

        if channel_flag == 1 or channel_flag == 2:
            vgg_model = VGG_16_4_channels(classes=classes,
                                          spatial_size=spatial_size)
            vgg_model.compile(loss='categorical_crossentropy',
                              optimizer=adam,
                              metrics=[metrics.categorical_accuracy])
        else:
            vgg_model = VGG_16(spatial_size=spatial_size,
                               classes=classes,
                               weights_path='VGG_Face_Deep_16.h5')
            vgg_model.compile(loss='categorical_crossentropy',
                              optimizer=adam,
                              metrics=[metrics.categorical_accuracy])

        svm_classifier = SVC(kernel='linear', C=1)
        ####################################################################################

        ############ for tensorboard ###############
        if tensorboard_flag == 1:
            cat_path = tensorboard_path + str(sub) + "/"
            os.mkdir(cat_path)
            tbCallBack = keras.callbacks.TensorBoard(log_dir=cat_path,
                                                     write_graph=True)

            cat_path2 = tensorboard_path + str(sub) + "spat/"
            os.mkdir(cat_path2)
            tbCallBack2 = keras.callbacks.TensorBoard(log_dir=cat_path2,
                                                      write_graph=True)
        #############################################

        image_label_mapping = np.empty([0])

        Train_X, Train_Y, Test_X, Test_Y, Test_Y_gt = data_loader_with_LOSO(
            sub, SubperdB, labelperSub, subjects, classes)

        # Rearrange training/testing sequences into a vector of images, breaking sequence
        Train_X_spatial = Train_X.reshape(Train_X.shape[0] * timesteps_TIM, r,
                                          w, channel)
        Test_X_spatial = Test_X.reshape(Test_X.shape[0] * timesteps_TIM, r, w,
                                        channel)

        # Special Loading for 4-Channel
        if channel_flag == 1:
            Train_X_Strain, _, Test_X_Strain, _, _ = data_loader_with_LOSO(
                sub, SubperdB_strain, labelperSub, subjects, classes)
            Train_X_Strain = Train_X_Strain.reshape(
                Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Test_X_Strain = Test_X_Strain.reshape(
                Test_X.shape[0] * timesteps_TIM, r, w, 1)

            # Concatenate Train X & Train_X_Strain
            Train_X_spatial = np.concatenate((Train_X_spatial, Train_X_Strain),
                                             axis=3)
            Test_X_spatial = np.concatenate((Test_X_spatial, Test_X_Strain),
                                            axis=3)

            channel = 4

        elif channel_flag == 2:
            Train_X_Strain, _, Test_X_Strain, _, _ = data_loader_with_LOSO(
                sub, SubperdB_strain, labelperSub, subjects, classes)
            Train_X_gray, _, Test_X_gray, _, _ = data_loader_with_LOSO(
                sub, SubperdB_gray, labelperSub, subjects)
            Train_X_Strain = Train_X_Strain.reshape(
                Train_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Test_X_Strain = Test_X_Strain.reshape(
                Test_X_Strain.shape[0] * timesteps_TIM, r, w, 1)
            Train_X_gray = Train_X_gray.reshape(
                Train_X_gray.shape[0] * timesteps_TIM, r, w, 3)
            Test_X_gray = Test_X_gray.reshape(
                Test_X_gray.shape[0] * timesteps_TIM, r, w, 3)

            # Concatenate Train_X_Strain & Train_X & Train_X_gray
            Train_X_spatial = np.concatenate(
                (Train_X_spatial, Train_X_Strain, Train_X_gray), axis=3)
            Test_X_spatial = np.concatenate(
                (Test_X_spatial, Test_X_Strain, Test_X_gray), axis=3)
            channel = 7

        if channel == 1:
            # Duplicate channel of input image
            Train_X_spatial = duplicate_channel(Train_X_spatial)
            Test_X_spatial = duplicate_channel(Test_X_spatial)

        # Repeat each Y label timesteps_TIM times, so that all images have labels
        Train_Y_spatial = np.repeat(Train_Y, timesteps_TIM, axis=0)
        Test_Y_spatial = np.repeat(Test_Y, timesteps_TIM, axis=0)

        # print ("Train_X_shape: " + str(np.shape(Train_X_spatial)))
        # print ("Train_Y_shape: " + str(np.shape(Train_Y_spatial)))
        # print ("Test_X_shape: " + str(np.shape(Test_X_spatial)))
        # print ("Test_Y_shape: " + str(np.shape(Test_Y_spatial)))
        # print(Train_X_spatial)
        ##################### Training & Testing #########################

        X = Train_X_spatial.reshape(Train_X_spatial.shape[0], channel, r, w)
        y = Train_Y_spatial.reshape(Train_Y_spatial.shape[0], classes)
        normalized_X = X.astype('float32') / 255.

        test_X = Test_X_spatial.reshape(Test_X_spatial.shape[0], channel, r, w)
        test_y = Test_Y_spatial.reshape(Test_Y_spatial.shape[0], classes)
        normalized_test_X = test_X.astype('float32') / 255.

        print(X.shape)

        ###### conv weights must be frozen for transfer learning ######
        if finetuning_flag == 1:
            for layer in vgg_model.layers[:33]:
                layer.trainable = False

        if train_spatial_flag == 1 and train_temporal_flag == 1:
            # Autoencoder features
            # conv_ae.fit(normalized_X, normalized_X, batch_size=batch_size, epochs=spatial_epochs, shuffle=True)

            # Spatial Training
            if tensorboard_flag == 1:
                vgg_model.fit(X,
                              y,
                              batch_size=batch_size,
                              epochs=spatial_epochs,
                              shuffle=True,
                              callbacks=[tbCallBack2])
            else:
                vgg_model.fit(X,
                              y,
                              batch_size=batch_size,
                              epochs=spatial_epochs,
                              shuffle=True,
                              callbacks=[history, stopping])

            # record loss and accuracy
            file_loss = open(
                workplace + 'Classification/' + 'Result/' + dB + '/loss_' +
                str(train_id) + '.txt', 'a')
            file_loss.write(str(history.losses) + "\n")
            file_loss.close()

            file_loss = open(
                workplace + 'Classification/' + 'Result/' + dB + '/accuracy_' +
                str(train_id) + '.txt', 'a')
            file_loss.write(str(history.accuracy) + "\n")
            file_loss.close()

            vgg_model.save_weights(spatial_weights_name + str(sub) + ".h5")