Exemple #1
0
def validate(best_acc, epoch, Vis=None):
    """Evaluate the global ``model`` on ``val_loader`` and track the best accuracy.

    Predictions are the per-pixel argmax over dim 1 of the model output and
    are accumulated, together with the labels, into a two-class
    ``Seg_metrics`` instance.

    Args:
        best_acc: Best accuracy (in percent) seen so far.
        epoch: Index of the current epoch; used for logging and bookkeeping.
        Vis: Optional visualisation helper exposing ``writer.add_text`` and
            ``visual_data_curve``. When ``None``, visual logging is skipped.

    Returns:
        The (possibly updated) best accuracy in percent.

    Side effects:
        Appends the current accuracy to the global ``acc_all``; when the
        accuracy improves, updates the global ``best_acc_epoch`` and saves the
        model ``state_dict`` under ``checkpoints/network_state/`` (relative to
        the current working directory).
    """
    global best_acc_epoch
    global model

    acc_metrics = Seg_metrics(num_classes=2)
    model.eval()
    # Inference only: make sure no autograd graph is built even when the
    # caller did not wrap this call in torch.no_grad() itself.
    with torch.no_grad():
        for x, y, _ in val_loader:
            pre = model(x.to(opt.device))
            pre_y = torch.argmax(pre, dim=1)
            acc_metrics.add_batch(y.cpu(), pre_y.cpu())

    acc = acc_metrics.pixelAccuracy()
    recall = acc_metrics.classRecall()

    cur_acc = round(acc * 100, 2)
    acc_all.append(cur_acc)

    if cur_acc > best_acc:
        best_acc = cur_acc
        best_acc_epoch = epoch
        torch.save(model.state_dict(), 'checkpoints/network_state/acc{}_model.pth'.format(best_acc))
        print('save best_acc_model.pth successfully in the {} epoch!'.format(epoch))

    # Vis defaults to None, so only log to the board when one was supplied.
    if Vis is not None:
        text_note_acc = "The best_acc gens in the {}_epoch,the best acc is {}". \
            format(best_acc_epoch, best_acc)
        text_note_recall = "the recall is {}".format(round(recall, 2))

        # Note recording where the best accuracy was reached
        Vis.writer.add_text(tag="note", text_string=text_note_acc + "||" + text_note_recall,
                            global_step=epoch)
        Vis.visual_data_curve(name="acc", data=cur_acc, data_index=epoch)
        Vis.visual_data_curve(name="recall", data=recall, data_index=epoch)

    print("\n epoch:{}-acc:{}--recall:{}".format(epoch, cur_acc, recall))
    return best_acc
Exemple #2
0
def main():
    """Run the training loop with periodic checkpointing and validation.

    For every epoch in ``range(start_epoch, opt.epochs)`` the global ``model``
    is trained on ``train_loader``; the mean batch loss is appended to the
    global ``loss_mean`` and plotted. Every ``opt.epoch_interval`` epochs the
    model/optimizer state is saved, and every ``opt.val_epoch`` epochs the
    model is validated and additionally scored on the training set.
    """
    global loss_all
    global loss_mean
    global model

    # TensorBoard visualisation
    TIMESTAMP = "{0:%Y-%m-%dII%H-%M-%S/}".format(datetime.now())
    log_dir = base_path + '/checkpoints/vis_log/' + TIMESTAMP
    print("The log save in {}".format(log_dir))
    Vis = VisualBoard(log_dir)
    best_acc = 0
    for epoch in range(start_epoch, opt.epochs):
        model.train()
        for cnt, (x, y, image_label) in enumerate(train_loader):
            x = x.to(opt.device)
            y = y.to(opt.device)

            pre = model(x)
            loss = criterion(pre, y.long())

            # Record the loss as a plain Python float. Storing the tensor
            # itself would keep every batch's loss (and its autograd
            # references) alive on the device for the whole epoch.
            loss_value = loss.item()
            loss_all.append(loss_value)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sys.stdout.write('\r epoch:{}-batch:{}-loss:{}'.format(epoch, cnt, loss_value))
            sys.stdout.flush()

        # Mean loss of this epoch
        b_loss = sum(loss_all) / len(loss_all)
        loss_mean.append(b_loss)
        loss_all = []

        # Plot the loss curve
        Vis.visual_data_curve(name="loss", data=b_loss, data_index=epoch)

        if epoch % opt.epoch_interval == opt.epoch_interval - 1:
            network_state = {'model': model.state_dict(),
                             'optimizer': optimizer.state_dict(),
                             'epoch': epoch}
            torch.save(network_state, base_path + '/checkpoints/network_state/network_epo{}.pth'.format(epoch))
            print('\n save model.pth successfully!')
        # Evaluation: disable gradient tracking and switch the model to eval
        # mode to reduce memory usage.
        with torch.no_grad():
            if epoch % opt.val_epoch == opt.val_epoch - 1:
                model.eval()
                # Validation returns the best accuracy so far, saves the
                # best-accuracy weights and plots recall/acc on the board.
                best_acc = validate(best_acc, epoch, Vis=Vis)
                # Report how well the model fits the training set
                acc_metrics = Seg_metrics(num_classes=2)
                for cnt, (x, y, image_label) in enumerate(train_loader):
                    pre = model(x.to(opt.device))
                    pre_y = torch.argmax(pre, dim=1)
                    acc_metrics.add_batch(y.cpu(), pre_y.cpu())
                train_acc = acc_metrics.pixelAccuracy()
                train_recall = acc_metrics.classRecall()
                print("训练集精度为:{},召回率为:{}".format(round(train_acc*100, 2), round(train_recall*100, 2)))
    Vis.visual_close()
    # NOTE(review): the statements below read max_area, metrics, c_name,
    # image_path, concat_image and result_acc, none of which are assigned in
    # the visible code above. This looks like the body of a per-image test
    # loop whose header is missing from this excerpt -- confirm against the
    # original file before relying on it.
    # y1 = np.max(np.array(out_cv))
    # Classify by the size of the largest connected region found after
    # threshold segmentation.
    y1 = 1 if max_area > 200 else 0

    # Compute accuracy through the confusion matrix
    # y1 = np.max(out_cv)
    if y1 == 1:
        y1 = np.array([1])
    else:
        y1 = np.array([0])
    if image_label == 1:
        label = np.array([1])
    else:
        label = np.array([0])

    metrics.add_batch(label, y1)
    acc = metrics.pixelAccuracy()
    # Save each result into a folder chosen from the confusion matrix
    # (TP, FP, FN, TN); the folder name is c_name[index of the largest cell].
    confusionMatrix = metrics.confusionMatrix
    # Reset so that the matrix read above reflects only the sample just added.
    metrics.reset()
    confusionMatrix = confusionMatrix.reshape(1, -1)
    image_save_path = base_path + '/checkpoints/test_result/{}'.format(
        c_name[np.argmax(confusionMatrix, axis=1)[0]])
    if not os.path.exists(image_save_path):
        os.makedirs(image_save_path)
    image_save_path = os.path.join(image_save_path,
                                   image_path[0].split("/")[-1])
    cv2.imwrite(image_save_path, concat_image)
    result_acc.append(acc)

result = 0
def validate(best_iou, best_acc, epoch, Vis=None, best_model_flag=False):
    """Evaluate the global ``model`` on ``val_loader`` for accuracy and mean IoU.

    For each sample the channel-1 softmax output is thresholded to obtain a
    segmentation mask (used for IoU) and an image-level prediction (positive
    when ``cal_max_area`` reports a non-zero area; used for accuracy).

    Args:
        best_iou: Best mean IoU (in percent) seen so far.
        best_acc: Best image-level accuracy (in percent) seen so far.
        epoch: Index of the current epoch; used for logging and file names.
        Vis: Optional visualisation helper exposing ``writer.add_text`` and
            ``visual_data_curve``. When ``None``, visual logging is skipped.
        best_model_flag: Incoming flag; set to ``True`` when this call
            improves on ``best_acc``, otherwise returned unchanged.

    Returns:
        Tuple ``(best_iou, best_acc, best_model_flag)``.

    Side effects:
        Appends to the globals ``acc_all`` and ``iou_all``, updates
        ``best_iou_epoch`` / ``best_acc_epoch`` on improvement, and writes
        checkpoints under ``checkpoints/network_state/`` (relative to the
        current working directory).
    """
    global best_acc_epoch
    global best_iou_epoch
    global model

    acc_metrics = Seg_metrics(num_classes=2)
    iou_metrics = Seg_metrics(num_classes=2)
    model.eval()
    # Inference only: avoid building an autograd graph even when the caller
    # did not wrap this call in torch.no_grad() itself.
    with torch.no_grad():
        for x, y, _, image_label in val_loader:
            output = model(x.to(opt.device))
            output = F.softmax(output, dim=1)
            # NOTE(review): squeezing dim 0 and then indexing [1] only selects
            # the channel-1 map when the batch size is 1 -- confirm the loader.
            output = output.squeeze(dim=0)
            out = output[1]

            out_image = trans(out.cpu())  # cpu
            label_image = trans(y[0])
            out_image = np.where(np.array(out_image) > 128, 1, 0)
            label_image = np.where(np.array(label_image) > 254, 1, 0)

            out_cv1 = out.detach().cpu()
            out_cv1 = np.uint8(out_cv1 * 255)
            _, out_cv = cv2.threshold(out_cv1, 128, 255, cv2.THRESH_BINARY)
            max_area = cal_max_area(out_cv)

            # Alternative: classify from the thresholded map alone
            # y1 = np.max(np.array(out_cv))
            # Classify by the size of the largest connected region found
            # after threshold segmentation.
            y1 = np.array([1]) if max_area > 0 else np.array([0])
            # y1 = np.max(out_image)
            if image_label == 1:
                label = np.array([1])
            else:
                label = np.array([0])
            acc_metrics.add_batch(label, y1)

            # Accumulate pixel-level statistics for mean IoU
            iou_metrics.add_batch(label_image.reshape(1, -1),
                                  out_image.reshape(1, -1))

    acc = acc_metrics.pixelAccuracy()
    recall = acc_metrics.TPR()

    cur_acc = round(acc * 100, 2)
    acc_all.append(cur_acc)

    iou = iou_metrics.meanIntersectionOverUnion()
    cur_iou = round(iou * 100, 2)
    iou_all.append(cur_iou)

    if cur_iou > best_iou:
        best_iou = cur_iou
        best_iou_epoch = epoch
        torch.save(model.state_dict(),
                   'checkpoints/network_state/best_iou_model.pth')
        print('\nsave best_iou_model.pth successfully in the {} epoch!'.format(
            epoch))

    if cur_acc > best_acc:
        best_model_flag = True
        best_acc = cur_acc
        best_acc_epoch = epoch

        # Avoid keeping several .pth files for the same epoch. The trailing
        # underscore is required: a bare "epoch1*" pattern would also match
        # (and delete) epoch10, epoch11, ... checkpoints.
        remove_old_pths = glob(
            "checkpoints/network_state/epoch{}_*".format(epoch))
        for remove_old_pth in remove_old_pths:
            if os.path.exists(remove_old_pth):
                os.remove(remove_old_pth)

        torch.save(
            model.state_dict(),
            'checkpoints/network_state/epoch{}_acc{}_model.pth'.format(
                epoch, best_acc))

        print('\nsave best_acc_model.pth successfully in the {} epoch!'.format(
            epoch))

    # Vis defaults to None, so only log to the board when one was supplied.
    if Vis is not None:
        text_note_iou = "The best_iou gens in the {}_epoch, the best iou is {}". \
            format(best_iou_epoch, best_iou)
        text_note_acc = "The best_acc gens in the {}_epoch,the best acc is {}". \
            format(best_acc_epoch, best_acc)
        text_note_recall = "the recall is {}".format(round(recall, 2))

        # Note recording where the best acc / iou were reached
        Vis.writer.add_text(tag="note",
                            text_string=text_note_iou + "||" + text_note_acc +
                            "," + text_note_recall,
                            global_step=epoch)
        Vis.visual_data_curve(name="acc", data=cur_acc, data_index=epoch)
        Vis.visual_data_curve(name="iou", data=cur_iou, data_index=epoch)

    print("\nepoch:{}-val_acc:{}--val_iou:{}".format(epoch, cur_acc, cur_iou))
    return best_iou, best_acc, best_model_flag
Exemple #5
0
# Empty the test-result folder. Only remove it when it exists so that a first
# run (no previous results) does not fail, and create missing parents too.
test_result_dir = base_path + '/checkpoints/test_result'
if os.path.isdir(test_result_dir):
    shutil.rmtree(test_result_dir)
os.makedirs(test_result_dir, exist_ok=True)

# Grad-CAM visualisation
cam = GradCAM(model=model, cam_layer="feature.7.1.bn2")
output_path = 'checkpoints/cam_output'
os.makedirs(output_path, exist_ok=True)
cam_display(cam=cam, visual_data=test_loader)
model.to(opt.device)

for i, (x, y, image_path) in enumerate(tqdm(test_loader)):
    pre = model(x.to(opt.device))
    pre_y = torch.argmax(pre, dim=1)
    # Label first, prediction second, matching the other add_batch call sites
    # in this file; the reversed order would swap the FP and FN cells.
    # NOTE(review): confirm against the Seg_metrics.add_batch signature.
    metrics.add_batch(y.cpu(), pre_y.cpu())
    acc = metrics.pixelAccuracy()
    # Save each image into a folder chosen from the confusion matrix
    # (TP, FP, FN, TN); the folder name is c_name[index of the largest cell].
    confusionMatrix = metrics.confusionMatrix
    # Reset so that the matrix read above reflects only the current batch.
    metrics.reset()
    confusionMatrix = confusionMatrix.reshape(1, -1)
    image_save_path = base_path + '/checkpoints/test_result/{}'.format(c_name[np.argmax(confusionMatrix, axis=1)[0]])
    os.makedirs(image_save_path, exist_ok=True)
    image_save_path = os.path.join(image_save_path, os.path.basename(image_path[0]))
    image = cv2.imread(image_path[0])
    # cv2.imread returns None for an unreadable file; skip the copy instead of
    # letting cv2.imwrite fail on it.
    if image is not None:
        cv2.imwrite(image_save_path, image)
    result_acc.append(acc)

result = 0
for acc in result_acc: