import os

import numpy as np
import tensorflow as tf

# NOTE: DataProcessing, creat_model, eval_util, remove_file and args are
# assumed to be defined/imported elsewhere in this project.


def eval():  # NOTE: shadows the built-in eval(); name kept to match the codebase
    g1 = tf.Graph()
    with g1.as_default():
        # Prefer the GPU if one is available, otherwise fall back to CPU
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            data_object = DataProcessing(args)
            valid_batch_example = data_object.input_frame_data(
                frame_path=args.valid_path, batch_size=160, num_epoch=1)
            # Build the model's compute graph; the app_tag_model function
            # contains the actual graph implementation
            model = creat_model(sess, args, isTraining=False)

            # Check whether a trained checkpoint already exists
            ckpt = tf.train.get_checkpoint_state(args.model_dir)
            if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
                print("valid step: Reading model parameters from {}".format(
                    ckpt.model_checkpoint_path))
                saver = tf.train.Saver(tf.global_variables())
                # Restore: assign each stored value to its matching tensor
                saver.restore(sess=sess, save_path=ckpt.model_checkpoint_path)
                print(sess.run(model.global_step))
            else:
                if not os.path.exists(args.model_dir):
                    os.makedirs(args.model_dir)
                print("valid step: Created new model parameters..")
                sess.run(tf.global_variables_initializer())

            valid_cate2_perr_list = []
            valid_cate2_gap_list = []
            valid_total_loss_list = []
            try:
                while True:
                    context_parsed, sequence_parsed = sess.run(
                        valid_batch_example)
                    batch_origin_labels = [
                        np.nonzero(row)[0].tolist()
                        for row in context_parsed['labels']
                    ]
                    (cate1_multilabel, cate2_multilabel, batch_origin_cate1,
                     batch_origin_cate2) = data_object.get_cate1_cate2_label(
                         batch_origin_labels)
                    batch_vid_name = np.asarray(context_parsed['id'])
                    batch_cate1_label_multiHot = np.asarray(
                        cate1_multilabel)  # batch, cate1_nums
                    batch_cate2_label_multiHot = np.asarray(
                        cate2_multilabel)  # batch, cate2_nums
                    batch_rgb_fea_float_list = np.asarray(
                        sequence_parsed['rgb'])  # batch, max_frame, 1024
                    batch_audio_fea_float_list = np.asarray(
                        sequence_parsed['audio'])  # batch, max_frame, 128
                    batch_num_audio_rgb_true_frame = np.asarray(
                        context_parsed['num_audio_rgb_true_frame'])
                    feed = dict(
                        zip([
                            model.input_video_vidName,
                            model.input_cate1_multilabel,
                            model.input_cate2_multilabel,
                            model.input_video_RGB_feature,
                            model.input_video_Audio_feature,
                            model.input_rgb_audio_true_frame,
                            model.dropout_keep_prob
                        ], [
                            batch_vid_name, batch_cate1_label_multiHot,
                            batch_cate2_label_multiHot,
                            batch_rgb_fea_float_list,
                            batch_audio_fea_float_list,
                            batch_num_audio_rgb_true_frame, 1.0
                        ]))
                    cate2_probs, total_loss = sess.run(
                        [model.cate2_probs, model.total_loss], feed)
                    cate2_perr = eval_util.calculate_precision_at_equal_recall_rate(
                        cate2_probs, batch_cate2_label_multiHot)
                    cate2_gap = eval_util.calculate_gap(
                        cate2_probs, batch_cate2_label_multiHot)
                    valid_cate2_perr_list.append(cate2_perr)
                    valid_cate2_gap_list.append(cate2_gap)
                    valid_total_loss_list.append(total_loss)
            except tf.errors.OutOfRangeError:
                # The input pipeline ran for one epoch; aggregate the metrics
                print("end!")
                valid_cate2_perr_aver_loss = np.mean(valid_cate2_perr_list)
                valid_cate2_gap_aver_loss = np.mean(valid_cate2_gap_list)
                valid_total_valid_aver_loss = np.mean(valid_total_loss_list)
                print('total valid cate2_perr_aver_loss: %0.4f' %
                      valid_cate2_perr_aver_loss)
                print('total valid cate2_gap_aver_loss: %0.4f' %
                      valid_cate2_gap_aver_loss)
                print('***********************')
                print('total valid total_valid_aver_loss: %0.4f' %
                      valid_total_valid_aver_loss)
            return valid_cate2_gap_aver_loss
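# -----------------------------------------------------------------------------
# For reference: a minimal NumPy sketch of the two metrics used above. This is
# NOT the project's eval_util (which follows the YouTube-8M starter code); it
# is an illustrative re-implementation under the usual definitions, assuming
# probs and labels are arrays of shape (batch, num_classes):
#   PERR: per example with n positive labels, precision within the top-n
#         scored classes, averaged over examples with at least one label.
#   GAP:  average precision over the top-k predictions of all examples,
#         pooled together and ranked globally by confidence.
def _sketch_perr(probs, labels):
    perrs = []
    for p, l in zip(probs, labels):
        n = int(l.sum())
        if n == 0:
            continue  # examples without positives do not contribute
        top_n = np.argsort(p)[::-1][:n]
        perrs.append(l[top_n].mean())
    return float(np.mean(perrs)) if perrs else 0.0


def _sketch_gap(probs, labels, top_k=20):
    scores, hits = [], []
    for p, l in zip(probs, labels):
        top = np.argsort(p)[::-1][:top_k]  # keep each example's top_k guesses
        scores.extend(p[top])
        hits.extend(l[top])
    order = np.argsort(scores)[::-1]  # rank all kept guesses globally
    hits = np.asarray(hits, dtype=np.float64)[order]
    num_pos = labels.sum()
    if num_pos == 0:
        return 0.0
    precision_at_i = np.cumsum(hits) / (np.arange(len(hits)) + 1.0)
    return float(np.sum(precision_at_i * hits) / num_pos)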
def predict_test():
    if os.path.exists(args.predict_out_path):
        remove_file(args.predict_out_path)
    with open(args.predict_out_path, 'a+') as fout1:
        fout1.write("VideoId,LabelConfidencePairs" + '\n')
    data_object = DataProcessing(args)
    test_batch_example = data_object.input_frame_data(
        frame_path=args.test_path, batch_size=200, num_epoch=1)
    # Prefer the GPU if one is available, otherwise fall back to CPU
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Build the model's compute graph.
        # NOTE: unlike eval(), there is no explicit Saver.restore() here, so
        # creat_model is assumed to load the trained parameters itself.
        model = creat_model(sess, args, isTraining=False)
        num = 0
        try:
            while True:
                batch_vid, batch_index, batch_value = [], [], []
                context_parsed, sequence_parsed = sess.run(test_batch_example)
                batch_origin_labels = [
                    np.nonzero(row)[0].tolist()
                    for row in context_parsed['labels']
                ]
                (cate1_multilabel, cate2_multilabel, batch_origin_cate1,
                 batch_origin_cate2) = data_object.get_cate1_cate2_label(
                     batch_origin_labels)
                batch_vid_name = np.asarray(context_parsed['id'])
                batch_cate1_label_multiHot = np.asarray(
                    cate1_multilabel)  # batch, cate1_nums
                batch_cate2_label_multiHot = np.asarray(
                    cate2_multilabel)  # batch, cate2_nums
                batch_rgb_fea_float_list = np.asarray(
                    sequence_parsed['rgb'])  # batch, max_frame, 1024
                batch_audio_fea_float_list = np.asarray(
                    sequence_parsed['audio'])  # batch, max_frame, 128
                batch_num_audio_rgb_true_frame = np.asarray(
                    context_parsed['num_audio_rgb_true_frame'])
                feed = dict(
                    zip([
                        model.input_video_vidName,
                        model.input_cate1_multilabel,
                        model.input_cate2_multilabel,
                        model.input_video_RGB_feature,
                        model.input_video_Audio_feature,
                        model.input_rgb_audio_true_frame,
                        model.dropout_keep_prob
                    ], [
                        batch_vid_name, batch_cate1_label_multiHot,
                        batch_cate2_label_multiHot, batch_rgb_fea_float_list,
                        batch_audio_fea_float_list,
                        batch_num_audio_rgb_true_frame, 1.0
                    ]))
                cate2_top20_probs_index, cate2_top20_probs_value = sess.run(
                    [
                        model.cate2_top20_probs_index,
                        model.cate2_top20_probs_value
                    ], feed)
                num += 1
                print("testSet step_num: ", num)
                for one_batch_index in cate2_top20_probs_index:
                    batch_index.append(list(one_batch_index))
                for one_batch_value in cate2_top20_probs_value:
                    batch_value.append(list(one_batch_value))
                for one_batch_vid in batch_vid_name:
                    batch_vid.append(list(one_batch_vid))
                # Append one "vid,label score label score ..." line per video
                with open(args.predict_out_path, 'a+') as fout2:
                    for vid, index, value in zip(batch_vid, batch_index,
                                                 batch_value):
                        one_result = []
                        tmp_index_value_list = []
                        vid = vid[0].decode(encoding='utf-8')
                        one_result.append(vid)
                        one_result.append(',')
                        for k, v in zip(index, value):
                            tmp_index_value_list.append(str(k))
                            tmp_index_value_list.append("%.6f" % v)
                        one_result.append(' '.join(tmp_index_value_list))
                        fout2.write(''.join(one_result) + '\n')
        except tf.errors.OutOfRangeError:
            print("predict/test processing is finished!")
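# -----------------------------------------------------------------------------
# For reference: the file written by predict_test() has one header line
# ("VideoId,LabelConfidencePairs") followed by lines of the form
#   <vid>,<label_1> <score_1> <label_2> <score_2> ...   (top-20 pairs)
# A minimal sketch of a reader for that format; read_predictions is a
# hypothetical helper, not part of this codebase:
def read_predictions(path):
    results = {}
    with open(path) as fin:
        next(fin)  # skip the CSV header line
        for line in fin:
            vid, pairs = line.rstrip('\n').split(',', 1)
            tokens = pairs.split()
            # tokens alternate: label index, confidence
            results[vid] = [(int(tokens[i]), float(tokens[i + 1]))
                            for i in range(0, len(tokens), 2)]
    return results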
def train():
    '''
    Build the model and run the training task.
    :return:
    '''
    valid_max_accuracy = -9999
    # Prefer the GPU if one is available, otherwise fall back to CPU
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        data_object = DataProcessing(args)
        train_batch_example = data_object.input_frame_data(
            frame_path=args.train_path,
            batch_size=args.batch_size,
            num_epoch=args.epoch)
        # Build the model's compute graph
        model = creat_model(sess, args, isTraining=True)
        # max_to_keep=3 keeps only the 3 most recent checkpoints
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        print("Begin training..")
        train_cate1_perr_list = []
        train_cate1_gap_list = []
        train_cate2_perr_list = []
        train_cate2_gap_list = []
        train_total_loss_list = []
        try:
            while True:
                context_parsed, sequence_parsed = sess.run(train_batch_example)
                batch_origin_labels = [
                    np.nonzero(row)[0].tolist()
                    for row in context_parsed['labels']
                ]
                (cate1_multilabel, cate2_multilabel, batch_origin_cate1,
                 batch_origin_cate2) = data_object.get_cate1_cate2_label(
                     batch_origin_labels)
                batch_vid_name = np.asarray(context_parsed['id'])
                batch_num_audio_rgb_true_frame = np.asarray(
                    context_parsed['num_audio_rgb_true_frame'])
                batch_cate1_label_multiHot = np.asarray(
                    cate1_multilabel)  # batch, cate1_nums
                batch_cate2_label_multiHot = np.asarray(
                    cate2_multilabel)  # batch, cate2_nums
                batch_rgb_fea_float_list = np.asarray(
                    sequence_parsed['rgb'])  # batch, max_frame, 1024
                batch_audio_fea_float_list = np.asarray(
                    sequence_parsed['audio'])  # batch, max_frame, 128
                feed = dict(
                    zip([
                        model.input_video_vidName,
                        model.input_cate1_multilabel,
                        model.input_cate2_multilabel,
                        model.input_video_RGB_feature,
                        model.input_video_Audio_feature,
                        model.input_rgb_audio_true_frame,
                        model.dropout_keep_prob
                    ], [
                        batch_vid_name, batch_cate1_label_multiHot,
                        batch_cate2_label_multiHot, batch_rgb_fea_float_list,
                        batch_audio_fea_float_list,
                        batch_num_audio_rgb_true_frame, 0.5
                    ]))
                cate1_probs, cate2_probs, total_loss, _ = sess.run([
                    model.cate1_probs, model.cate2_probs, model.total_loss,
                    model.optimizer
                ], feed)
                train_cate1_perr = eval_util.calculate_precision_at_equal_recall_rate(
                    cate1_probs, batch_cate1_label_multiHot)
                train_cate1_gap = eval_util.calculate_gap(
                    cate1_probs, batch_cate1_label_multiHot)
                train_cate2_perr = eval_util.calculate_precision_at_equal_recall_rate(
                    cate2_probs, batch_cate2_label_multiHot)
                train_cate2_gap = eval_util.calculate_gap(
                    cate2_probs, batch_cate2_label_multiHot)
                train_cate1_perr_list.append(train_cate1_perr)
                train_cate1_gap_list.append(train_cate1_gap)
                train_cate2_perr_list.append(train_cate2_perr)
                train_cate2_gap_list.append(train_cate2_gap)
                train_total_loss_list.append(total_loss)
                if model.global_step.eval() % args.report_freq == 0:
                    print("report_freq: ", args.report_freq)
                    print('cate1_train: Step:{} ; aver_train_cate1_perr:{} ; '
                          'aver_train_cate1_gap:{} ; aver_total_loss:{}'.format(
                              model.global_step.eval(),
                              np.mean(train_cate1_perr_list),
                              np.mean(train_cate1_gap_list),
                              np.mean(train_total_loss_list)))
                    print('cate2_train: Step:{} ; aver_train_cate2_perr:{} ; '
                          'aver_train_cate2_gap:{} ; aver_total_loss:{}'.format(
                              model.global_step.eval(),
                              np.mean(train_cate2_perr_list),
                              np.mean(train_cate2_gap_list),
                              np.mean(train_total_loss_list)))
                    # Reset the reporting window
                    train_cate1_perr_list = []
                    train_cate1_gap_list = []
                    train_cate2_perr_list = []
                    train_cate2_gap_list = []
                    train_total_loss_list = []
                if (model.global_step.eval() > 1
                        and model.global_step.eval() % args.valid_freq == 0):
                    # Measure accuracy on the validation set
                    print("valid infer is processing!")
                    print('model.global_step.eval(): ',
                          model.global_step.eval())
                    valid_cate2_gap_aver_loss = eval()
                    # Save only the model with the best validation GAP so far
                    if valid_cate2_gap_aver_loss > valid_max_accuracy:
                        print("save the model, step = ",
                              model.global_step.eval())
                        valid_max_accuracy = valid_cate2_gap_aver_loss
                        checkpoint_path = os.path.join(args.model_dir,
                                                       'model.ckpt')
                        saver.save(sess=sess,
                                   save_path=checkpoint_path,
                                   global_step=model.global_step.eval())
        except tf.errors.OutOfRangeError:
            print("train processing is finished!")
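# -----------------------------------------------------------------------------
# For reference: creat_model (spelling as in this codebase) is called in all
# three functions above but defined elsewhere. Below is a minimal sketch of
# the usual TF1 "create or restore" helper it is assumed to implement,
# mirroring the restore logic inside eval(). `Model` is a hypothetical class
# name standing in for the real graph builder (app_tag_model in the comments
# above); the actual implementation may differ.
def creat_model_sketch(sess, args, isTraining):
    model = Model(args, isTraining=isTraining)  # builds the compute graph
    ckpt = tf.train.get_checkpoint_state(args.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from {}".format(
            ckpt.model_checkpoint_path))
        # Assign the checkpointed values to the freshly built variables
        tf.train.Saver(tf.global_variables()).restore(
            sess, ckpt.model_checkpoint_path)
    else:
        print("Created new model with fresh parameters.")
        sess.run(tf.global_variables_initializer())
    return model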