Esempio n. 1
0
def validate(best_acc, epoch, Vis=None):
    """Evaluate the model on the validation loader and track the best accuracy.

    Runs the module-level ``model`` over ``val_loader``, accumulates the
    arg-max predictions into a two-class ``Seg_metrics`` object, appends the
    pixel accuracy (as a percentage rounded to two decimals) to the
    module-level ``acc_all`` list, and saves the model weights whenever that
    accuracy exceeds ``best_acc``.

    Args:
        best_acc: Best accuracy (percentage) observed so far.
        epoch: Index of the current epoch; used as the logging step and
            recorded in ``best_acc_epoch`` when a new best is reached.
        Vis: Optional visual board exposing ``writer.add_text`` and
            ``visual_data_curve``. When ``None``, board logging is skipped.

    Returns:
        The best accuracy after this validation pass.
    """
    global best_acc_epoch

    acc_metrics = Seg_metrics(num_classes=2)
    model.eval()
    # Inference only: disable autograd here so the function is safe to call
    # even when the caller has not wrapped it in ``torch.no_grad()``.
    with torch.no_grad():
        for x, y, _ in val_loader:
            pre = model(x.to(opt.device))
            pre_y = torch.argmax(pre, dim=1)
            acc_metrics.add_batch(y.cpu(), pre_y.cpu())

    acc = acc_metrics.pixelAccuracy()
    recall = acc_metrics.classRecall()

    cur_acc = round(acc * 100, 2)
    acc_all.append(cur_acc)

    if cur_acc > best_acc:
        best_acc = cur_acc
        best_acc_epoch = epoch
        # Save under base_path so the file lands in the same directory tree
        # as the periodic checkpoints, independent of the working directory.
        torch.save(model.state_dict(),
                   base_path + '/checkpoints/network_state/acc{}_model.pth'.format(best_acc))
        print('save best_acc_model.pth successfully in the {} epoch!'.format(epoch))

    if Vis is not None:
        text_note_acc = "The best_acc gens in the {}_epoch,the best acc is {}". \
            format(best_acc_epoch, best_acc)
        text_note_recall = "the recall is {}".format(round(recall, 2))

        # Text note recording where the best accuracy was reached.
        Vis.writer.add_text(tag="note", text_string=text_note_acc + "||" + text_note_recall,
                            global_step=epoch)
        Vis.visual_data_curve(name="acc", data=cur_acc, data_index=epoch)
        Vis.visual_data_curve(name="recall", data=recall, data_index=epoch)

    print("\n epoch:{}-acc:{}--recall:{}".format(epoch, cur_acc, recall))
    return best_acc
Esempio n. 2
0
def main():
    """Train the model, with periodic checkpointing and validation.

    For every epoch from ``start_epoch`` up to ``opt.epochs`` this runs one
    pass over ``train_loader``, logs the mean batch loss to the visual board,
    saves a full training checkpoint (model, optimizer, epoch) every
    ``opt.epoch_interval`` epochs, and every ``opt.val_epoch`` epochs calls
    ``validate`` and prints accuracy/recall measured on the training set.

    The module-level ``loss_all`` list is used as a per-epoch scratch buffer
    and reset after each epoch; the per-epoch mean is appended to
    ``loss_mean``. Both hold plain Python floats.
    """
    global loss_all
    global loss_mean

    # TensorBoard visualisation.
    TIMESTAMP = "{0:%Y-%m-%dII%H-%M-%S/}".format(datetime.now())
    log_dir = base_path + '/checkpoints/vis_log/' + TIMESTAMP
    print("The log save in {}".format(log_dir))
    Vis = VisualBoard(log_dir)
    best_acc = 0
    try:
        for epoch in range(start_epoch, opt.epochs):
            model.train()
            for cnt, (x, y, image_label) in enumerate(train_loader):
                x = x.to(opt.device)
                y = y.to(opt.device)

                pre = model(x)
                loss = criterion(pre, y.long())

                # Record the loss as a plain float. Storing the tensor itself
                # would keep its autograd history (and device memory) alive
                # for every batch of the epoch.
                loss_all.append(loss.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                sys.stdout.write('\r epoch:{}-batch:{}-loss:{}'.format(epoch, cnt, loss))
                sys.stdout.flush()

            # Mean loss over the batches recorded during this epoch.
            b_loss = sum(loss_all) / len(loss_all)
            loss_mean.append(b_loss)
            loss_all = []

            # Plot the loss curve.
            Vis.visual_data_curve(name="loss", data=b_loss, data_index=epoch)

            if epoch % opt.epoch_interval == opt.epoch_interval - 1:
                network_state = {'model': model.state_dict(),
                                 'optimizer': optimizer.state_dict(),
                                 'epoch': epoch}
                torch.save(network_state,
                           base_path + '/checkpoints/network_state/network_epo{}.pth'.format(epoch))
                print('\n save model.pth successfully!')

            # Evaluation: disable gradient tracking and switch the model to
            # eval mode to reduce memory usage.
            with torch.no_grad():
                if epoch % opt.val_epoch == opt.val_epoch - 1:
                    model.eval()
                    # Validation step: returns the best accuracy so far, saves
                    # the weights of the best model, and plots the recall and
                    # accuracy curves on the board.
                    best_acc = validate(best_acc, epoch, Vis=Vis)
                    # Report how well the model fits the training set.
                    acc_metrics = Seg_metrics(num_classes=2)
                    for x, y, _ in train_loader:
                        pre = model(x.to(opt.device))
                        pre_y = torch.argmax(pre, dim=1)
                        acc_metrics.add_batch(y.cpu(), pre_y.cpu())
                    train_acc = acc_metrics.pixelAccuracy()
                    train_recall = acc_metrics.classRecall()
                    print("训练集精度为:{},召回率为:{}".format(round(train_acc*100, 2), round(train_recall*100, 2)))
    finally:
        # Always close the board so buffered events are not lost when
        # training is interrupted by an exception.
        Vis.visual_close()