Exemple #1
0
def validate(args):
    """Evaluate a segmentation model on the CamVid 'val' split and print scores.

    The model is either loaded whole from ``args.validate_model`` or built
    from ``args.structure`` and optionally initialised from
    ``args.validate_model_state_dict``. Every sample is run through the model
    (batch size 1), the arg-max over dimension 1 of the output is taken as the
    predicted label map, and all predictions are handed to ``scores``.
    The entries of the returned score mapping and the per-class values are
    printed.

    Args:
        args: option object. The attributes read here are ``vis``,
            ``dataset_path`` (empty string selects ``~/Data/CamVid``),
            ``validate_model``, ``structure``, ``validate_model_state_dict``
            and ``blend``.

    Raises:
        ValueError: if no whole-model file is given and ``args.structure`` is
            not one of the supported architecture names (previously this
            surfaced later as a ``NameError`` on ``model``).
    """
    # The client is created as before but is not used further in this function.
    if args.vis:
        vis = visdom.Visdom()

    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True, split='val')
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)

    if args.validate_model != '':
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        model = torch.load(args.validate_model)
    else:
        n_classes = dst.n_classes
        # Lambdas keep construction lazy: only the selected class name is
        # looked up, exactly as with the former if/elif chain.
        builders = {
            'fcn32s': lambda: fcn(module_type='32s', n_classes=n_classes),
            'fcn16s': lambda: fcn(module_type='16s', n_classes=n_classes),
            'fcn8s': lambda: fcn(module_type='8s', n_classes=n_classes),
            'ResNetDUC': lambda: ResNetDUC(n_classes=n_classes),
            'segnet': lambda: segnet(n_classes=n_classes),
            'ENet': lambda: ENet(n_classes=n_classes),
            'drn_d_22': lambda: DRNSeg(model_name='drn_d_22',
                                       n_classes=n_classes),
            'pspnet': lambda: pspnet(n_classes=n_classes, use_aux=False),
            'erfnet': lambda: erfnet(n_classes=n_classes),
        }
        build = builders.get(args.structure)
        if build is None:
            raise ValueError(
                'unsupported structure: {!r}'.format(args.structure))
        model = build()

        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(
                    torch.load(args.validate_model_state_dict))
            except KeyError:
                # Best-effort: continue with whatever weights the model has.
                print('missing key')
    model.eval()

    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap the batch tensors as autograd Variables before the forward pass.
        imgs = Variable(imgs)
        labels = Variable(labels)

        outputs = model(imgs)
        # outputs is batch_size x n_classes x height x width; max over dim 1
        # returns (values, indices) and the indices are the predicted labels.
        pred = outputs.data.max(1)[1].numpy()
        gt = labels.data.numpy()

        # Every 50th sample, optionally dump an image/prediction overlay.
        if args.vis and i % 50 == 0 and args.blend:
            # CHW -> HWC, rescale and add a per-channel offset.
            # NOTE(review): presumably this undoes the loader's
            # normalisation; confirm against camvidLoader.
            img_hwc = imgs.data.numpy()[0].transpose(1, 2, 0) * 255.0
            img_hwc += np.array([104.00699, 116.66877, 122.67892])
            img_hwc = np.array(img_hwc, dtype=np.uint8)

            pred_color_hwc = np.array(dst.decode_segmap(pred[0]),
                                      dtype=np.uint8)

            label_blend = img_hwc * 0.5 + pred_color_hwc * 0.5
            label_blend = np.array(label_blend, dtype=np.uint8)
            # The same file is overwritten on every dump.
            misc.imsave('/tmp/label_blend.png', label_blend)

        gts.extend(gt)
        preds.extend(pred)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
Exemple #2
0
def validate(args):
    """Evaluate a segmentation model on CamVid or CityScapes and print scores.

    The model is either loaded whole from ``args.validate_model`` or built
    from ``args.structure`` and optionally initialised from
    ``args.validate_model_state_dict``. Each sample (batch size 1) is run
    through the model, the arg-max over dimension 1 of the output is taken as
    the predicted label map, and all predictions are handed to ``scores``.
    The entries of the returned score mapping and the per-class values are
    printed.

    Args:
        args: option object. The attributes read here are ``vis``,
            ``dataset_path`` (empty string selects ``~/Data/CamVid``; a
            leading ``~`` is expanded otherwise), ``dataset``,
            ``dataset_type`` (CamVid split), ``validate_model``,
            ``structure``, ``init_vgg16``, ``validate_model_state_dict`` and
            ``cuda``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'CamVid'`` nor
            ``'CityScapes'``, or if no whole-model file is given and
            ``args.structure`` is not a supported architecture name. Both
            cases previously surfaced later as a ``NameError``.
    """
    # The client is created as before but is not used further in this function.
    if args.vis:
        vis = visdom.Visdom()

    # Fall back to the default location only when no path was given; the
    # previous code unconditionally overwrote this default with the (empty)
    # argument.
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = os.path.expanduser(args.dataset_path)

    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path,
                           is_transform=True,
                           split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True)
    else:
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)

    if args.validate_model != '':
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        model = torch.load(args.validate_model)
    else:
        n_classes = dst.n_classes
        # Lambdas keep construction lazy: only the selected class name and
        # args.init_vgg16 (where used) are looked up, as with the former
        # if/elif chain.
        builders = {
            'fcn32s': lambda: fcn(
                module_type='32s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn16s': lambda: fcn(
                module_type='16s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn8s': lambda: fcn(
                module_type='8s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_resnet18_32s': lambda: fcn_resnet18(
                module_type='32s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_resnet18_16s': lambda: fcn_resnet18(
                module_type='16s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_resnet18_8s': lambda: fcn_resnet18(
                module_type='8s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_resnet34_32s': lambda: fcn_resnet34(
                module_type='32s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_resnet34_16s': lambda: fcn_resnet34(
                module_type='16s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_resnet34_8s': lambda: fcn_resnet34(
                module_type='8s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_MobileNet_32s': lambda: fcn_MobileNet(
                module_type='32s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_MobileNet_16s': lambda: fcn_MobileNet(
                module_type='16s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'fcn_MobileNet_8s': lambda: fcn_MobileNet(
                module_type='8s', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'ResNetDUC': lambda: ResNetDUC(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'ResNetDUCHDC': lambda: ResNetDUCHDC(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'segnet': lambda: segnet(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'segnet_vgg19': lambda: segnet_vgg19(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'segnet_unet': lambda: segnet_unet(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'segnet_alignres': lambda: segnet_alignres(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'sqnet': lambda: sqnet(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'segnet_squeeze': lambda: segnet_squeeze(
                n_classes=n_classes, pretrained=args.init_vgg16),
            'ENet': lambda: ENet(n_classes=n_classes),
            'ENetV2': lambda: ENetV2(n_classes=n_classes),
            'drn_d_22': lambda: DRNSeg(
                model_name='drn_d_22', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'drn_a_50': lambda: DRNSeg(
                model_name='drn_a_50', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'drn_a_18': lambda: DRNSeg(
                model_name='drn_a_18', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'drn_e_22': lambda: DRNSeg(
                model_name='drn_e_22', n_classes=n_classes,
                pretrained=args.init_vgg16),
            'erfnet': lambda: erfnet(n_classes=n_classes),
            'fcdensenet103': lambda: fcdensenet103(n_classes=n_classes),
            'fcdensenet56': lambda: fcdensenet56(n_classes=n_classes),
            'fcdensenet_tiny': lambda: fcdensenet_tiny(n_classes=n_classes),
            'Res_Deeplab_101': lambda: Res_Deeplab_101(n_classes=n_classes),
            'Res_Deeplab_50': lambda: Res_Deeplab_50(n_classes=n_classes),
            'EDANet': lambda: EDANet(n_classes=n_classes),
            'drn_a_asymmetric_18': lambda: DRNSeg(
                model_name='drn_a_asymmetric_18', n_classes=n_classes,
                pretrained=False),
        }
        build = builders.get(args.structure)
        if build is None:
            raise ValueError(
                'unsupported structure: {!r}'.format(args.structure))
        model = build()

        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(
                    torch.load(args.validate_model_state_dict))
            except KeyError:
                # Best-effort: continue with whatever weights the model has.
                print('missing key')
    if args.cuda:
        model.cuda()
    model.eval()

    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap the batch tensors as autograd Variables before the forward pass.
        imgs = Variable(imgs)
        labels = Variable(labels)

        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()

        outputs = model(imgs)
        # outputs is batch_size x n_classes x height x width; max over dim 1
        # returns (values, indices) and the indices are the predicted labels.
        # Results are brought back to the CPU before converting to numpy.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()

        gts.extend(gt)
        preds.extend(pred)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
Exemple #3
0
def train(args):
    """Train a segmentation model on CamVid with SGD.

    The model is either loaded whole from ``args.resume_model`` or built from
    ``args.structure`` and optionally initialised from
    ``args.resume_model_state_dict``. When resuming, the starting epoch is
    parsed from the checkpoint path: the integer between the last ``'_'`` and
    the last ``'.'``. Training runs from ``start_epoch + 1`` up to (but not
    including) epoch 20000, printing the loss for every batch.

    Args:
        args: option object. The attributes read here are ``vis``,
            ``dataset_path`` (empty string selects ``~/Data/CamVid``),
            ``data_augment``, ``n_classes``, ``batch_size``, ``resume_model``,
            ``structure``, ``init_vgg16``, ``resume_model_state_dict``,
            ``cuda``, ``lr``, ``save_model`` and ``save_epoch``.

    Raises:
        ValueError: if no whole-model file is given and ``args.structure`` is
            not a supported architecture name (previously a ``NameError``),
            if a resume path does not carry a parsable epoch number, or if
            the training loader yields no batches.
    """
    def _epoch_from_checkpoint(path):
        # Checkpoints are expected to be named '<anything>_<epoch>.<ext>'.
        return int(path[path.rfind('_') + 1:path.rfind('.')])

    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()

    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True,
                       is_augment=args.data_augment)
    # Force the class count to the value requested on the command line.
    dst.n_classes = args.n_classes
    trainloader = torch.utils.data.DataLoader(
        dst, batch_size=args.batch_size, shuffle=True)

    start_epoch = 0
    if args.resume_model != '':
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        model = torch.load(args.resume_model)
        start_epoch = _epoch_from_checkpoint(args.resume_model)
    else:
        n_classes = dst.n_classes
        # Lambdas keep construction lazy: only the selected class name is
        # looked up, exactly as with the former if/elif chain.
        builders = {
            'fcn32s': lambda: fcn(module_type='32s', n_classes=n_classes,
                                  pretrained=args.init_vgg16),
            'fcn16s': lambda: fcn(module_type='16s', n_classes=n_classes,
                                  pretrained=args.init_vgg16),
            'fcn8s': lambda: fcn(module_type='8s', n_classes=n_classes,
                                 pretrained=args.init_vgg16),
            'ResNetDUC': lambda: ResNetDUC(n_classes=n_classes,
                                           pretrained=args.init_vgg16),
            'segnet': lambda: segnet(n_classes=n_classes,
                                     pretrained=args.init_vgg16),
            'ENet': lambda: ENet(n_classes=n_classes,
                                 pretrained=args.init_vgg16),
            'drn_d_22': lambda: DRNSeg(model_name='drn_d_22',
                                       n_classes=n_classes,
                                       pretrained=args.init_vgg16),
            'pspnet': lambda: pspnet(n_classes=n_classes,
                                     pretrained=args.init_vgg16,
                                     use_aux=False),
            'erfnet': lambda: erfnet(n_classes=n_classes),
        }
        build = builders.get(args.structure)
        if build is None:
            raise ValueError(
                'unsupported structure: {!r}'.format(args.structure))
        model = build()

        if args.resume_model_state_dict != '':
            try:
                # fcn32s/fcn16s/fcn8s differ slightly in their parameters, so
                # loading one variant's weights into another can raise
                # KeyError; that case is tolerated for now.
                start_epoch = _epoch_from_checkpoint(
                    args.resume_model_state_dict)
                pretrained_dict = torch.load(args.resume_model_state_dict)
                model.load_state_dict(pretrained_dict)
            except KeyError:
                print('missing key')

    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, momentum=0.99, weight_decay=5e-4)

    for epoch in range(start_epoch + 1, 20000, 1):
        loss_epoch = 0
        batch_count = 0
        for i, (imgs, labels) in enumerate(trainloader):
            print(i)
            batch_count = i + 1

            imgs = Variable(imgs)
            labels = Variable(labels)

            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs = model(imgs)

            if args.vis and i % 50 == 0:
                # Move to the CPU first: .numpy() is not available on CUDA
                # tensors, so the previous code failed here with args.cuda.
                pred_labels = outputs.data.max(1)[1].cpu().numpy()
                gt_labels = labels.data.cpu().numpy()
                label_color = dst.decode_segmap(
                    gt_labels[0]).transpose(2, 0, 1)
                pred_label_color = dst.decode_segmap(
                    pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color, win=win)
                win = 'pred_label_color'
                vis.image(pred_label_color, win=win)

                # Keep on-disk snapshots only for the first 100 epochs.
                if epoch < 100:
                    snapshot_dir = '/tmp/' + init_time
                    if not os.path.exists(snapshot_dir):
                        os.mkdir(snapshot_dir)
                    time_str = str(int(time.time()))
                    label_hwc = label_color.transpose(1, 2, 0)
                    pred_hwc = pred_label_color.transpose(1, 2, 0)
                    # The printed tags are kept verbatim; the shapes shown
                    # are those of the (1, 2, 0)-transposed arrays.
                    print('label_color.transpose(2, 0, 1).shape:',
                          label_hwc.shape)
                    print('pred_label_color.transpose(2, 0, 1).shape:',
                          pred_hwc.shape)
                    cv2.imwrite(
                        snapshot_dir + '/' + time_str + '_label.png',
                        label_hwc)
                    cv2.imwrite(
                        snapshot_dir + '/' + time_str + '_pred_label.png',
                        pred_hwc)

            # Gradients accumulate across backward passes unless cleared.
            optimizer.zero_grad()

            loss = cross_entropy2d(outputs, labels)
            loss_numpy = loss.cpu().data.numpy()
            loss_epoch += loss_numpy
            print('loss:', loss_numpy)
            loss.backward()

            optimizer.step()

            # Plot the per-batch loss curve for the current epoch.
            if args.vis:
                win = 'loss'
                # Reuse the CPU copy and flatten it so a 1-D array is passed
                # regardless of whether the loss is 0-d or shape (1,).
                loss_y = np.asarray(loss_numpy).reshape(-1)
                win_res = vis.line(X=np.ones(1) * i, Y=loss_y, win=win,
                                   update='append')
                if win_res != win:
                    # Appending did not target the window; create it.
                    vis.line(X=np.ones(1) * i, Y=loss_y, win=win)

        # Close the per-batch window so the next epoch starts a fresh curve.
        if args.vis:
            win = 'loss'
            vis.close(win)

        if batch_count == 0:
            raise ValueError('training loader produced no batches')

        # Mean loss over the batches actually seen (the previous code divided
        # by the last batch index, i.e. one too few).
        loss_avg_epoch = loss_epoch / float(batch_count)
        # Plot the per-epoch average loss curve.
        if args.vis:
            win = 'loss_epoch'
            loss_avg_y = np.asarray(loss_avg_epoch).reshape(-1)
            win_res = vis.line(X=np.ones(1) * epoch, Y=loss_avg_y, win=win,
                               update='append')
            if win_res != win:
                vis.line(X=np.ones(1) * epoch, Y=loss_avg_y, win=win)

        if args.save_model and epoch % args.save_epoch == 0:
            torch.save(model.state_dict(),
                       '{}_camvid_class_{}_{}.pt'.format(
                           args.structure, dst.n_classes, epoch))