def validate(best_acc, epoch, Vis=None):
    """Evaluate the global ``model`` on ``val_loader`` and track the best accuracy.

    Predictions for the whole validation loader are accumulated in a
    ``Seg_metrics(num_classes=2)`` instance; pixel accuracy and per-class
    recall are then read from it. The accuracy, as a percentage rounded to
    two decimals, is appended to the global ``acc_all`` list. When it beats
    ``best_acc``, the model weights are saved to
    ``checkpoints/network_state/acc{best_acc}_model.pth`` (relative to the
    current working directory) and the global ``best_acc_epoch`` is updated.

    Args:
        best_acc: best validation accuracy (percentage) seen so far.
        epoch: current epoch index, used for logging and bookkeeping.
        Vis: optional logger exposing ``writer.add_text`` and
            ``visual_data_curve``; when ``None`` TensorBoard logging is skipped.

    Returns:
        The (possibly updated) best accuracy.
    """
    global best_acc_epoch
    acc_metrics = Seg_metrics(num_classes=2)
    model.eval()
    # Inference only: disable autograd so no graphs are built even when the
    # caller has not already wrapped this call in torch.no_grad().
    with torch.no_grad():
        for x, y, _image_label in val_loader:
            pre = model(x.to(opt.device))
            pre_y = torch.argmax(pre, dim=1)  # predicted class index per pixel
            acc_metrics.add_batch(y.cpu(), pre_y.cpu())

    acc = acc_metrics.pixelAccuracy()
    recall = acc_metrics.classRecall()
    cur_acc = round(acc * 100, 2)
    acc_all.append(cur_acc)

    if cur_acc > best_acc:
        best_acc = cur_acc
        best_acc_epoch = epoch
        torch.save(model.state_dict(),
                   'checkpoints/network_state/acc{}_model.pth'.format(best_acc))
        print('save best_acc_model.pth successfully in the {} epoch!'.format(epoch))

    # The default Vis=None previously crashed on Vis.writer; only log when a
    # visualiser was actually supplied.
    if Vis is not None:
        text_note_acc = "The best_acc gens in the {}_epoch,the best acc is {}". \
            format(best_acc_epoch, best_acc)
        text_note_recall = "the recall is {}".format(round(recall, 2))
        # Note on which epoch produced the best acc (shown in TensorBoard).
        Vis.writer.add_text(tag="note",
                            text_string=text_note_acc + "||" + text_note_recall,
                            global_step=epoch)
        Vis.visual_data_curve(name="acc", data=cur_acc, data_index=epoch)
        Vis.visual_data_curve(name="recall", data=recall, data_index=epoch)
    print("\n epoch:{}-acc:{}--recall:{}".format(epoch, cur_acc, recall))
    return best_acc
def main():
    """Train the global ``model``, checkpointing and validating periodically.

    For every epoch in ``range(start_epoch, opt.epochs)``:

    * one pass over ``train_loader`` with ``criterion``/``optimizer``; the
      mean batch loss is appended to the global ``loss_mean`` and plotted;
    * every ``opt.epoch_interval`` epochs, model/optimizer state and the epoch
      number are saved to
      ``base_path/checkpoints/network_state/network_epo{epoch}.pth``;
    * every ``opt.val_epoch`` epochs, ``validate`` is called and pixel
      accuracy / recall on the training set are printed.

    TensorBoard logs are written under ``base_path/checkpoints/vis_log/``
    in a directory named after the start time.
    """
    global loss_all  # rebound to [] at the end of every epoch
    # TensorBoard visualisation: one log directory per run.
    TIMESTAMP = "{0:%Y-%m-%dII%H-%M-%S/}".format(datetime.now())
    log_dir = base_path + '/checkpoints/vis_log/' + TIMESTAMP
    print("The log save in {}".format(log_dir))
    Vis = VisualBoard(log_dir)
    best_acc = 0
    try:
        for epoch in range(start_epoch, opt.epochs):
            model.train()
            for cnt, (x, y, _image_label) in enumerate(train_loader):
                x = x.to(opt.device)
                y = y.to(opt.device)
                pre = model(x)
                loss = criterion(pre, y.long())
                # Record a plain Python float. Appending the tensor itself
                # would keep every batch's autograd graph alive until the end
                # of the epoch and steadily leak (GPU) memory.
                loss_value = loss.item()
                loss_all.append(loss_value)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                sys.stdout.write('\r epoch:{}-batch:{}-loss:{}'.format(epoch, cnt, loss_value))
                sys.stdout.flush()

            # Mean loss of this epoch; skipped for an empty loader instead of
            # raising ZeroDivisionError.
            if loss_all:
                b_loss = sum(loss_all) / len(loss_all)
                loss_mean.append(b_loss)
                # Plot the loss curve.
                Vis.visual_data_curve(name="loss", data=b_loss, data_index=epoch)
            loss_all = []

            if epoch % opt.epoch_interval == opt.epoch_interval - 1:
                network_state = {'model': model.state_dict(),
                                 'optimizer': optimizer.state_dict(),
                                 'epoch': epoch}
                torch.save(network_state,
                           base_path + '/checkpoints/network_state/network_epo{}.pth'.format(epoch))
                print('\n save model.pth successfully!')

            if epoch % opt.val_epoch == opt.val_epoch - 1:
                # Evaluation: disable gradient tracking to save memory, and
                # switch to eval() so BN layers use their running statistics.
                with torch.no_grad():
                    model.eval()
                    # Validation returns the best acc so far, saves the best-acc
                    # weights and plots the acc/recall curves in TensorBoard.
                    best_acc = validate(best_acc, epoch, Vis=Vis)
                    # Report how well the model fits the training set itself.
                    acc_metrics = Seg_metrics(num_classes=2)
                    for x, y, _image_label in train_loader:
                        pre = model(x.to(opt.device))
                        pre_y = torch.argmax(pre, dim=1)
                        acc_metrics.add_batch(y.cpu(), pre_y.cpu())
                    train_acc = acc_metrics.pixelAccuracy()
                    train_recall = acc_metrics.classRecall()
                    print("训练集精度为:{},召回率为:{}".format(round(train_acc*100, 2), round(train_recall*100, 2)))
    finally:
        # Close the TensorBoard writer even if training is interrupted.
        Vis.visual_close()
test_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False) # model model = Res18_UNet(n_classes=2, layer=4).to(opt.device) # device_ids = [0] # model = torch.nn.DataParallel(model, device_ids=device_ids) # model = model.cuda(device=device_ids[0]) model.eval() checkpoints = torch.load(os.path.join(base_path, opt.state_path)) # model.load_state_dict(checkpoints["model"]) model.load_state_dict(checkpoints) trans = tran.ToPILImage() metrics = Seg_metrics(num_classes=2) result_acc = [] result_iou = [] c_name = ["TP", "FN", "FP", "TN"] # 清空文件夹 shutil.rmtree(base_path + '/checkpoints/test_result') os.mkdir(base_path + '/checkpoints/test_result') for i, (x, y, image_path, image_label) in enumerate(tqdm(test_loader)): output = model(x.to(opt.device)) output = F.softmax(output, dim=1) output = output.squeeze(dim=0) out = output[1].unsqueeze(0) # 反归一化
def validate(best_iou, best_acc, epoch, Vis=None, best_model_flag=False):
    """Evaluate the global ``model`` on ``val_loader`` for image accuracy and pixel IoU.

    Each validation batch is treated as a single image (the code indexes
    ``squeeze(dim=0)`` / ``y[0]``; batch size 1 is assumed — TODO confirm):

    * the class-1 softmax map is binarised (> 128 after conversion to a
      0-255 PIL image) and compared with the ground-truth mask (> 254) to
      accumulate mean IoU;
    * the map is thresholded with OpenCV and the image is classified as
      positive when ``cal_max_area`` of the result is > 0 (classification by
      connected-region size); this is compared with ``image_label`` to
      accumulate accuracy.

    The rounded percentages are appended to the global ``acc_all`` and
    ``iou_all`` lists. A new best IoU saves
    ``checkpoints/network_state/best_iou_model.pth``. A new best accuracy
    replaces any checkpoint already saved for this epoch with
    ``checkpoints/network_state/epoch{epoch}_acc{best_acc}_model.pth`` and
    sets ``best_model_flag``. Paths are relative to the working directory.

    Args:
        best_iou: best mean IoU (percentage) seen so far.
        best_acc: best image-level accuracy (percentage) seen so far.
        epoch: current epoch index, used for logging and checkpoint names.
        Vis: optional logger exposing ``writer.add_text`` and
            ``visual_data_curve``; when ``None`` TensorBoard logging is skipped.
        best_model_flag: value returned unchanged unless accuracy improves,
            in which case ``True`` is returned.

    Returns:
        Tuple ``(best_iou, best_acc, best_model_flag)``.
    """
    global best_acc_epoch
    global best_iou_epoch
    acc_metrics = Seg_metrics(num_classes=2)
    iou_metrics = Seg_metrics(num_classes=2)
    model.eval()
    # Inference only: no autograd graphs are needed.
    with torch.no_grad():
        for x, y, _, image_label in val_loader:
            output = model(x.to(opt.device))
            output = F.softmax(output, dim=1)
            output = output.squeeze(dim=0)
            out = output[1]  # class-1 probability map

            # Pixel-level binary masks for the IoU metric.
            out_image = np.where(np.array(trans(out.cpu())) > 128, 1, 0)
            label_image = np.where(np.array(trans(y[0])) > 254, 1, 0)

            # Image-level prediction: threshold the probability map, then
            # classify by the size of the connected region (an earlier variant
            # used the max of the thresholded map alone).
            out_cv1 = np.uint8(out.detach().cpu() * 255)
            _, out_cv = cv2.threshold(out_cv1, 128, 255, cv2.THRESH_BINARY)
            max_area = cal_max_area(out_cv)
            y1 = np.array([1 if max_area > 0 else 0])
            label = np.array([1 if image_label == 1 else 0])
            acc_metrics.add_batch(label, y1)

            # Accumulate statistics for mean IoU.
            iou_metrics.add_batch(label_image.reshape(1, -1), out_image.reshape(1, -1))

    acc = acc_metrics.pixelAccuracy()
    recall = acc_metrics.TPR()
    cur_acc = round(acc * 100, 2)
    acc_all.append(cur_acc)
    iou = iou_metrics.meanIntersectionOverUnion()
    cur_iou = round(iou * 100, 2)
    iou_all.append(cur_iou)

    if cur_iou > best_iou:
        best_iou = cur_iou
        best_iou_epoch = epoch
        torch.save(model.state_dict(), 'checkpoints/network_state/best_iou_model.pth')
        print('\nsave best_iou_model.pth successfully in the {} epoch!'.format(
            epoch))

    if cur_acc > best_acc:
        best_model_flag = True
        best_acc = cur_acc
        best_acc_epoch = epoch
        # Avoid keeping several .pth files for the same epoch. The "_" after
        # the epoch number matters: "epoch1*" would also match (and delete)
        # checkpoints of epochs 10-19, 100-199, ... left over from an earlier
        # or resumed run.
        for old_pth in glob("checkpoints/network_state/epoch{}_*".format(epoch)):
            if os.path.exists(old_pth):
                os.remove(old_pth)
        torch.save(
            model.state_dict(),
            'checkpoints/network_state/epoch{}_acc{}_model.pth'.format(
                epoch, best_acc))
        print('\nsave best_acc_model.pth successfully in the {} epoch!'.format(
            epoch))

    # The default Vis=None previously crashed on Vis.writer; only log when a
    # visualiser was actually supplied.
    if Vis is not None:
        text_note_iou = "The best_iou gens in the {}_epoch, the best iou is {}". \
            format(best_iou_epoch, best_iou)
        text_note_acc = "The best_acc gens in the {}_epoch,the best acc is {}". \
            format(best_acc_epoch, best_acc)
        text_note_recall = "the recall is {}".format(round(recall, 2))
        # Note on which epochs produced the best acc / IoU checkpoints.
        Vis.writer.add_text(tag="note",
                            text_string=text_note_iou + "||" + text_note_acc + "," + text_note_recall,
                            global_step=epoch)
        Vis.visual_data_curve(name="acc", data=cur_acc, data_index=epoch)
        Vis.visual_data_curve(name="iou", data=cur_iou, data_index=epoch)
    print("\nepoch:{}-val_acc:{}--val_iou:{}".format(epoch, cur_acc, cur_iou))
    return best_iou, best_acc, best_model_flag