def val(model1, model2, dataParser, epoch):
    """Evaluate the two-stage model on ``dataParser`` and return averaged metrics.

    For each batch, ``model1`` (stage one) runs on the tampered image; ``model2``
    (stage two) receives the image multiplied by stage-one output 0,
    concatenated with the raw image along dim 1, plus stage-one outputs 1-3.
    Per-batch losses and scores are logged through the module-level ``writer``
    and accumulated in ``Averagvalue`` meters. No gradients are computed.

    Args:
        model1: stage-one network; its outputs 0-3 are used.
        model2: stage-two network; output 0 is scored against ``gt_dou_edge``
            and outputs 1-8 against the entries of ``relation_map``.
        dataParser: iterable of dicts with keys ``'tamper_image'``,
            ``'gt_band'``, ``'gt_dou_edge'`` and ``'relation_map'``; must
            support ``len()``. At most ``len(dataParser)`` batches are read.
        epoch: current epoch index, used for the TensorBoard step and the
            progress line.

    Returns:
        dict with ``'loss_avg'``, the stage-one / stage-two f1, precision,
        accuracy and recall averages, and ``'map8_loss'`` - a list of ten
        per-relation-map loss averages of which only the first eight are ever
        updated.
    """
    val_epoch = len(dataParser)
    # One device for inputs and the loss accumulator, instead of calling
    # .cuda() unconditionally on inputs but guarding it for the accumulator.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Running averages.
    batch_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()
    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()
    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # Ten meters keep the returned 'map8_loss' list at its original length;
    # two_stage_outputs[1:9] below yields at most eight maps.
    map8_loss_value = [Averagvalue() for _ in range(10)]

    # Switch both networks to evaluation mode.
    model1.eval()
    model2.eval()
    end = time.time()
    for batch_index, input_data in enumerate(dataParser):
        # Cap at exactly len(dataParser) batches. Checking at the top (rather
        # than after processing) avoids evaluating one extra batch when the
        # parser yields more items than it reports.
        if batch_index >= val_epoch:
            break
        step = epoch * val_epoch + batch_index

        images = input_data['tamper_image'].to(device)
        labels_band = input_data['gt_band'].to(device)
        labels_dou_edge = input_data['gt_dou_edge'].to(device)
        relation_map = input_data['relation_map']

        with torch.no_grad():
            one_stage_outputs = model1(images)
            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2], one_stage_outputs[3])

            # Stage one: prediction vs. band ground truth.
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0], labels_band)

            # Stage two: edge loss weighted by 12 plus the summed
            # relation-map losses, the total divided by 20.
            edge_loss = wce_dice_huber_loss(two_stage_outputs[0], labels_dou_edge)
            loss_8t = torch.zeros((), device=device)
            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].to(device))
                loss_8t += one_loss_t
                map8_loss_value[c_index].update(one_loss_t.item())
            loss_stage_2 = (edge_loss * 12 + loss_8t) / 20

            # Reported validation loss: mean of the two stage losses.
            loss = (loss_stage_1 + loss_stage_2) / 2

        writer.add_scalar('val_stage_one_loss', loss_stage_1.item(), global_step=step)
        writer.add_scalar('val_stage_two_fuse_loss', loss_stage_2.item(), global_step=step)
        writer.add_scalar('val_fuse_loss_per_epoch', loss.item(), global_step=step)

        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # Score both stages; tags expand to e.g. 'val_f1_score_stage1'.
        for stage, pred, target, meters in (
                ('stage1', one_stage_outputs[0], labels_band,
                 (f1_value_stage1, precision_value_stage1,
                  acc_value_stage1, recall_value_stage1)),
                ('stage2', two_stage_outputs[0], labels_dou_edge,
                 (f1_value_stage2, precision_value_stage2,
                  acc_value_stage2, recall_value_stage2))):
            for metric, score_fn, meter in zip(
                    ('f1', 'precision', 'acc', 'recall'),
                    (my_f1_score, my_precision_score, my_acc_score, my_recall_score),
                    meters):
                score = score_fn(pred, target)
                writer.add_scalar('val_{}_score_{}'.format(metric, stage), score,
                                  global_step=step)
                meter.update(score)

        if batch_index % args.print_freq == 0:
            parts = [
                'Epoch: [{0}/{1}][{2}/{3}]'.format(epoch, args.maxepoch, batch_index, val_epoch),
                'Time {t.val:.3f} (avg:{t.avg:.3f})'.format(t=batch_time),
            ]
            for label, meter in (
                    ('两阶段总Loss', losses),
                    ('第一阶段Loss', loss_stage1),
                    ('第二阶段Loss', loss_stage2),
                    ('第一阶段:f1_score', f1_value_stage1),
                    ('第一阶段:precision_score:', precision_value_stage1),
                    ('第一阶段:acc_score', acc_value_stage1),
                    ('第一阶段:recall_score', recall_value_stage1),
                    ('第二阶段:f1_score', f1_value_stage2),
                    ('第二阶段:precision_score:', precision_value_stage2),
                    ('第二阶段:acc_score', acc_value_stage2),
                    ('第二阶段:recall_score', recall_value_stage2)):
                parts.append('{} {:f} (avg:{:f})'.format(label, meter.val, meter.avg))
            # Single-space separators everywhere; the previous concatenation
            # ran the acc/recall segments into the following label.
            print(' '.join(parts))

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch):
    """Train the two-stage model for one epoch and return averaged statistics.

    For each batch, ``model1`` (stage one) runs on the tampered image; ``model2``
    (stage two) receives the image multiplied by stage-one output 0,
    concatenated with the raw image along dim 1, plus stage-one outputs 1-3.
    The optimized objective is ``loss_stage_2 + 10 * sum(relation-map losses)``.
    ``loss_stage_1`` is computed and logged but is NOT part of that objective,
    so ``model1`` only receives gradients that flow back through ``model2``'s
    inputs.

    Batches whose image tensor does not have size 3 in dim 1 and size 320 in
    dim 2 are skipped, as are batches on which the stage-one forward raises.

    Args:
        model1: stage-one network; its outputs 0-3 are used.
        model2: stage-two network; output 0 is scored against ``gt_dou_edge``
            and outputs 1-8 against the entries of ``relation_map``.
        optimizer1: optimizer stepped after every processed batch (model1).
        optimizer2: optimizer stepped after every processed batch (model2).
        dataParser: iterable of dicts with keys ``'tamper_image'``,
            ``'gt_band'``, ``'gt_dou_edge'`` and ``'relation_map'``; must
            support ``len()``. At most ``len(dataParser)`` batches are read.
        epoch: current epoch index, used for the TensorBoard step and the
            progress line.

    Returns:
        dict with ``'loss_avg'``, the stage-one / stage-two f1, precision,
        accuracy and recall averages (currently constant placeholders, see
        below), and ``'map8_loss'`` - a list of ten per-relation-map loss
        averages of which only the first eight are ever updated.
    """
    train_epoch = len(dataParser)
    # One device for inputs and the loss accumulator, instead of calling
    # .cuda() unconditionally on inputs but guarding it for the accumulator.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Running averages.
    batch_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()
    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()
    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # Ten meters keep the returned 'map8_loss' list at its original length;
    # two_stage_outputs[1:9] below yields at most eight maps.
    map8_loss_value = [Averagvalue() for _ in range(10)]

    # Switch both networks to training mode.
    model1.train()
    model2.train()
    end = time.time()
    for batch_index, input_data in enumerate(dataParser):
        # Cap at exactly len(dataParser) batches. Checking at the top (rather
        # than after processing) avoids training on one extra batch when the
        # parser yields more items than it reports.
        if batch_index >= train_epoch:
            break
        step = epoch * train_epoch + batch_index

        raw_images = input_data['tamper_image']
        # Skip batches with an unexpected layout (presumably NCHW with 3
        # channels and height 320 - confirm against the data pipeline).
        if raw_images.shape[1] != 3 or raw_images.shape[2] != 320:
            # Restart the timer so the skipped batch is not charged to the next one.
            end = time.time()
            continue

        images = raw_images.to(device)
        labels_band = input_data['gt_band'].to(device)
        labels_dou_edge = input_data['gt_dou_edge'].to(device)
        relation_map = input_data['relation_map']

        with torch.set_grad_enabled(True):
            # Kept as in the original. Nothing here reads images.grad.
            # NOTE(review): possibly required if model1 uses reentrant
            # checkpointing - confirm before removing.
            images.requires_grad = True
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            try:
                one_stage_outputs = model1(images)
            except Exception as e:
                # Deliberate best-effort: report and skip this batch rather
                # than abort the epoch.
                print(e)
                print(images.shape)
                end = time.time()
                continue
            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2], one_stage_outputs[3])

            # Stage one: prediction vs. band ground truth (logged only).
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0], labels_band)

            # Stage two: edge loss plus summed relation-map losses.
            loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0], labels_dou_edge)
            loss_8t = torch.zeros((), device=device)
            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].to(device))
                loss_8t += one_loss_t
                map8_loss_value[c_index].update(one_loss_t.item())
            loss = loss_stage_2 + loss_8t * 10

            writer.add_scalars('loss_gather', {
                'stage_one_loss': loss_stage_1.item(),
                'stage_two_fuse_loss': loss_stage_2.item()
            }, global_step=step)

            loss.backward()
            optimizer1.step()
            optimizer2.step()

        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # NOTE(review): metric computation is disabled (calls to my_f1_score
        # and friends were commented out upstream); every score is the
        # constant placeholder 1, so these curves and averages carry no
        # information about model quality.
        placeholder_score = 1
        for tag, meter_stage1, meter_stage2 in (
                ('f1_score_stage', f1_value_stage1, f1_value_stage2),
                ('precision_score_stage', precision_value_stage1, precision_value_stage2),
                ('acc_score_stage', acc_value_stage1, acc_value_stage2),
                ('recall_score_stage', recall_value_stage1, recall_value_stage2)):
            writer.add_scalars(tag, {
                'stage1': placeholder_score,
                'stage2': placeholder_score
            }, global_step=step)
            meter_stage1.update(placeholder_score)
            meter_stage2.update(placeholder_score)

        if batch_index % args.print_freq == 0:
            parts = [
                'Epoch: [{0}/{1}][{2}/{3}]'.format(epoch, args.maxepoch, batch_index, train_epoch),
                'Time {t.val:.3f} (avg:{t.avg:.3f})'.format(t=batch_time),
            ]
            for label, meter in (
                    ('两阶段总Loss', losses),
                    ('第一阶段Loss', loss_stage1),
                    ('第二阶段Loss', loss_stage2),
                    ('第一阶段:f1_score', f1_value_stage1),
                    ('第一阶段:precision_score:', precision_value_stage1),
                    ('第一阶段:acc_score', acc_value_stage1),
                    ('第一阶段:recall_score', recall_value_stage1),
                    ('第二阶段:f1_score', f1_value_stage2),
                    ('第二阶段:precision_score:', precision_value_stage2),
                    ('第二阶段:acc_score', acc_value_stage2),
                    ('第二阶段:recall_score', recall_value_stage2)):
                parts.append('{} {:f} (avg:{:f})'.format(label, meter.val, meter.avg))
            # Single-space separators everywhere; the previous concatenation
            # ran the acc/recall segments into the following label.
            print(' '.join(parts))

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }