コード例 #1
0
ファイル: validate.py プロジェクト: eglrp/semseg
def validate(args):
    """Evaluate a segmentation model on one dataset split and print its metrics.

    The model is either loaded whole from ``args.validate_model`` or built from
    ``args.structure`` and optionally initialised from
    ``args.validate_model_state_dict``.  For every image the cross-entropy loss
    and the arg-max prediction are collected; afterwards the (up to) ten images
    with the highest loss are listed, followed by the overall and per-class
    scores returned by ``scores``.

    Args:
        args: parsed option namespace.  Attributes read here: ``vis``,
            ``dataset_path``, ``dataset``, ``dataset_type``, ``validate_model``,
            ``structure``, ``n_classes``, ``init_vgg16``,
            ``validate_model_state_dict``, ``cuda`` and ``save_result``.

    Raises:
        ValueError: if ``args.dataset`` is neither 'CamVid' nor 'CityScapes'.
        SystemExit: (status 1) if the network named by ``args.structure``
            cannot be constructed.
    """
    init_time = str(int(time.time()))
    if args.vis:
        # The handle is not used further down; the connection is still opened
        # so that enabling ``vis`` behaves as before.
        vis = visdom.Visdom()

    # Fall back to ~/Data/CamVid only when no path was given; otherwise expand
    # '~' in the user-supplied path.
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = os.path.expanduser(args.dataset_path)

    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        dst = cityscapesLoader(local_path, is_transform=True, split=args.dataset_type)
    else:
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))
    # batch_size=1 and shuffle=False keep the loader index aligned with
    # dst.files, which is used below to recover each image's file name.
    val_loader = torch.utils.data.DataLoader(dst, batch_size=1, shuffle=False)

    if args.validate_model != '':
        model = torch.load(args.validate_model)
    else:
        try:
            # NOTE(security): eval() runs whatever expression is in
            # args.structure; only pass trusted values here.
            model = eval(args.structure)(n_classes=args.n_classes, pretrained=args.init_vgg16)
        except Exception as err:
            print('missing structure or not support')
            print(repr(err))
            raise SystemExit(1)
        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(torch.load(args.validate_model_state_dict, map_location='cpu'))
            except KeyError:
                print('missing key')
    if args.cuda:
        model.cuda()
    model.eval()

    save_dir = '/tmp/' + init_time
    if args.save_result and not os.path.exists(save_dir):
        os.mkdir(save_dir)

    gts, preds, errors, imgs_name = [], [], [], []
    for i, (imgs, labels) in enumerate(val_loader):
        print(i)
        img_name = os.path.basename(dst.files[args.dataset_type][i])
        imgs_name.append(img_name)

        # Wrap the loader output as pytorch Variables.
        imgs = Variable(imgs)
        labels = Variable(labels)
        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()

        outputs = model(imgs)
        loss = cross_entropy2d(outputs, labels)
        errors.append(float(loss.cpu().data.numpy()))

        # max over axis 1 (the class axis of a batch*n_classes*height*width
        # output) returns (values, indices); the indices are the labels.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()

        if args.save_result:
            # Colourise the prediction and convert RGB -> BGR for cv2.imwrite.
            pred_label_color = dst.decode_segmap(pred[0])
            pred_label_color_cv2 = cv2.cvtColor(pred_label_color, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_dir + '/{}'.format(img_name), pred_label_color_cv2)

        gts.extend(gt)
        preds.extend(pred)

    # List the images the model handled worst: sort by loss, largest first,
    # and keep at most ten entries so short runs do not fail.
    worst_first = np.argsort(errors)[::-1][:10].tolist()
    for top_index in worst_first:
        print('img_name_top:', imgs_name[top_index])

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
コード例 #2
0
ファイル: train_mt.py プロジェクト: qxdaaaaa/semseg-1
def _vis_append_point(vis, win, x, y, xlabel, ylabel):
    """Append the sample (x, y) to the visdom line window ``win``, creating it if needed."""
    xs = np.ones(1) * x
    ys = np.expand_dims(y, axis=0)
    win_res = vis.line(X=xs, Y=ys, win=win, update='append')
    # A result other than the requested window name is treated as "the plot
    # does not exist yet", in which case it is created with title and labels.
    if win_res != win:
        vis.line(X=xs, Y=ys, win=win, opts=dict(title=win, xlabel=xlabel, ylabel=ylabel))


def train(args):
    """Jointly train ``drnsegmt_a_18`` for detection and semantic segmentation.

    Every epoch first runs one pass over the YOLO-style detection loader, then
    one pass over the segmentation loader, optionally validates on the 'val'
    split (saving the state dict whenever the mean IoU improves), plots
    progress to visdom when ``args.vis`` is set, and periodically saves the
    model state dict.

    Args:
        args: parsed option namespace.  Attributes read here: ``vis``,
            ``dataset_path``, ``dataset``, ``data_augment``,
            ``class_weighting``, ``cuda``, ``batch_size``, ``resume_model``,
            ``resume_model_state_dict``, ``init_vgg16``, ``n_classes``,
            ``solver``, ``lr``, ``lr_policy``, ``training_epoch``,
            ``val_interval``, ``save_model``, ``save_epoch`` and
            ``structure``.

    Raises:
        SystemExit: (status 1) for an unsupported dataset, solver or
            learning-rate policy.
    """
    init_time = str(int(time.time()))

    def type_callback(event):
        """On pressing 's' in the usage window, dump the loss curve to /tmp."""
        if event['event_type'] != 'KeyPress' or event['key'] != 's':
            return
        import json
        win_data_dict = json.loads(vis.get_window_data('loss_iteration'))
        first_trace = win_data_dict['content']['data'][0]
        win_data_x = np.array(first_trace['x'])
        win_data_y = np.array(first_trace['y'])

        win_data_save_file = '/tmp/loss_iteration_{}.txt'.format(init_time)
        # Text mode: the lines written are str, which a binary-mode file
        # rejects on Python 3.
        with open(win_data_save_file, 'w') as f:
            for item_x, item_y in zip(win_data_x, win_data_y):
                f.write("{} {}\n".format(item_x, item_y))
        done_time = str(int(time.time()))
        vis.text(vis_text_usage+'done at {}'.format(done_time), win=callback_text_usage_window)

    if args.vis:
        # start visdom and close all windows
        vis = visdom.Visdom()
        vis.close()

        vis_text_usage = 'Operating in the text window<br>Press s to save data<br>'
        callback_text_usage_window = vis.text(vis_text_usage)
        vis.register_event_handler(type_callback, callback_text_usage_window)

    # ---- datasets and optional class weights ----
    class_weight = None
    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'CamVid':
        train_dst = camvidLoader(local_path, is_transform=True, is_augment=args.data_augment, split='train')
        val_dst = camvidLoader(local_path, is_transform=True, is_augment=False, split='val')

        trainannot_image_dir = os.path.expanduser(os.path.join(local_path, "trainannot"))
        trainannot_image_files = [
            os.path.join(trainannot_image_dir, name)
            for name in os.listdir(trainannot_image_dir)
            if name.endswith('.png')
        ]
        if args.class_weighting == 'MFB':
            class_weight = torch.tensor(median_frequency_balancing(trainannot_image_files, num_classes=12))
        elif args.class_weighting == 'ENET':
            class_weight = torch.tensor(ENet_weighing(trainannot_image_files, num_classes=12))
    elif args.dataset == 'CityScapes':
        train_dst = cityscapesLoader(local_path, is_transform=True, split='train')
        val_dst = cityscapesLoader(local_path, is_transform=True, split='val')
    else:
        print('{} dataset does not implement'.format(args.dataset))
        raise SystemExit(1)

    if args.cuda and class_weight is not None:
        class_weight = class_weight.cuda()
    print('class_weight:', class_weight)

    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=1, shuffle=True)

    # ---- detection head data and loss ----
    yolo_B = 2
    yolo_C = 2
    yolo_S = 7
    yolo_out_tensor_shape = yolo_B * 5 + yolo_C
    print('yolo_out_tensor_shape:', yolo_out_tensor_shape)
    det_criterion = yoloLoss(yolo_S, yolo_B, yolo_C, 5, 0.5, args.cuda)

    det_file_root = os.path.expanduser('~/Data/CamVid/train/')
    det_train_dst = yoloDataset(root=det_file_root, list_file=['camvid_det.txt'], train=True, transform=[transforms.ToTensor()], yolo_out_tensor_shape=yolo_out_tensor_shape)
    det_train_loader = torch.utils.data.DataLoader(det_train_dst, batch_size=1, shuffle=True, num_workers=4)

    # ---- model (fresh, resumed whole, or resumed from a state dict) ----
    start_epoch = 0
    best_mIoU = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        # The epoch is the text between the last '_' and the last '.' of the
        # checkpoint file name.
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1+1:start_epoch_id2])
    else:
        model = drnsegmt_a_18(pretrained=args.init_vgg16, n_classes=args.n_classes, det_tensor_num=yolo_out_tensor_shape)
        if args.resume_model_state_dict != '':
            try:
                # Recover the best mIoU and the epoch from the checkpoint
                # name, which follows '..._miou_<mIoU>_class_<n>_<epoch>.pt'.
                miou_model_name_str = '_miou_'
                class_model_name_str = '_class_'
                miou_id1 = args.resume_model_state_dict.find(miou_model_name_str)+len(miou_model_name_str)
                miou_id2 = args.resume_model_state_dict.find(class_model_name_str)
                best_mIoU = float(args.resume_model_state_dict[miou_id1:miou_id2])

                start_epoch_id1 = args.resume_model_state_dict.rfind('_')
                start_epoch_id2 = args.resume_model_state_dict.rfind('.')
                start_epoch = int(args.resume_model_state_dict[start_epoch_id1 + 1:start_epoch_id2])
                pretrained_dict = torch.load(args.resume_model_state_dict, map_location='cpu')
                model.load_state_dict(pretrained_dict)
            except KeyError:
                print('missing resume_model_state_dict or wrong type')

    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)
    print('best_mIoU:', best_mIoU)

    # ---- optimizer and learning-rate schedule ----
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if args.solver == 'SGD':
        optimizer = torch.optim.SGD(trainable_params, lr=args.lr, momentum=0.99, weight_decay=5e-4)
    elif args.solver == 'RMSprop':
        optimizer = torch.optim.RMSprop(trainable_params, lr=args.lr, momentum=0.99, weight_decay=5e-4)
    elif args.solver == 'Adam':
        optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=5e-4)
    else:
        print('missing solver or not support')
        raise SystemExit(1)

    if args.lr_policy == 'Constant':
        scheduler = ConstantLR(optimizer)
    elif args.lr_policy == 'Polynomial':
        # base lr=0.01 power=0.9 like PSPNet
        scheduler = PolynomialLR(optimizer, max_iter=args.training_epoch, power=0.9)
    else:
        # Without a scheduler the epoch loop would fail on scheduler.step().
        print('missing lr_policy or not support')
        raise SystemExit(1)

    data_count = len(train_dst) // args.batch_size
    det_data_count = len(det_train_dst)
    print('data_count:', data_count)

    # Training predictions are only needed for the visdom accuracy plot, so
    # they are collected only when args.vis is set.
    train_gts, train_preds = [], []
    for epoch in range(start_epoch+1, args.training_epoch, 1):
        loss_epoch = 0
        scheduler.step()
        model.train()

        # ---- object detection pass ----
        for det_i, (det_imgs, det_labels, _) in enumerate(det_train_loader):
            det_imgs = Variable(det_imgs)
            det_labels = Variable(det_labels)
            if args.cuda:
                det_imgs = det_imgs.cuda()
                det_labels = det_labels.cuda()

            _, outputs_det = model(det_imgs)

            # scaled to balance the detection loss against the segmentation loss
            det_loss = 0.02 * det_criterion(outputs_det, det_labels)
            optimizer.zero_grad()
            det_loss.backward()
            optimizer.step()

            # loss curve over iterations
            if args.vis:
                det_loss_np = det_loss.cpu().data.numpy()
                _vis_append_point(vis, 'det_loss_iteration', det_i+det_data_count*(epoch-1)+1,
                                  det_loss_np, 'iteration', 'loss')

        # ---- semantic segmentation pass ----
        for i, (imgs, labels) in enumerate(train_loader):
            # The last batch may hold fewer than batch_size images (e.g. 3
            # left over with batch_size=4); it is skipped.
            if imgs.shape[0] != args.batch_size:
                break

            imgs = Variable(imgs)
            labels = Variable(labels)
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs_sem, _ = model(imgs)

            # Gradients accumulate across backward calls unless zeroed.
            optimizer.zero_grad()
            loss = cross_entropy2d(outputs_sem, labels, weight=class_weight)
            loss_np = loss.cpu().data.numpy()
            loss_epoch += loss_np
            loss.backward()
            optimizer.step()

            if args.vis:
                batch_pred = outputs_sem.cpu().data.max(1)[1].numpy()
                batch_gt = labels.cpu().data.numpy()
                train_gts.extend(batch_gt)
                train_preds.extend(batch_pred)

                if i % 50 == 0:
                    label_color = train_dst.decode_segmap(batch_gt[0]).transpose(2, 0, 1)
                    pred_label_color = train_dst.decode_segmap(batch_pred[0]).transpose(2, 0, 1)
                    vis.image(label_color, win='label_color', opts=dict(title='Gt', caption='Ground Truth'))
                    vis.image(pred_label_color, win='pred_label_color', opts=dict(title='Pred', caption='Prediction'))

                # loss curve over iterations
                _vis_append_point(vis, 'loss_iteration', i+data_count*(epoch-1)+1,
                                  loss_np, 'iteration', 'loss')

        # ---- validation: keep the checkpoint with the best mean IoU ----
        if args.val_interval > 0 and epoch % args.val_interval == 0:
            print('----starting val----')
            model.eval()

            val_gts, val_preds = [], []
            # No gradients are needed for evaluation.
            with torch.no_grad():
                for val_imgs, val_labels in val_loader:
                    val_imgs = Variable(val_imgs)
                    val_labels = Variable(val_labels)
                    if args.cuda:
                        val_imgs = val_imgs.cuda()
                        val_labels = val_labels.cuda()

                    val_outputs_sem, _ = model(val_imgs)
                    val_preds.extend(val_outputs_sem.cpu().data.max(1)[1].numpy())
                    val_gts.extend(val_labels.cpu().data.numpy())

            score, _ = scores(val_gts, val_preds, n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Mean IoU : \t':
                    if v > best_mIoU:
                        best_mIoU = v
                        torch.save(model.state_dict(), '{}_{}_miou_{}_class_{}_{}.pt'.format(args.structure, args.dataset, best_mIoU, args.n_classes, epoch))
                    # mIoU per validated epoch, plotted at the epoch it was measured
                    if args.vis:
                        _vis_append_point(vis, 'mIoU_epoch', epoch, v, 'epoch', 'mIoU')
            print('----ending   val----')

        # ---- per-epoch plots ----
        # max(..., 1) avoids dividing by zero when the dataset holds fewer
        # images than one batch.
        loss_avg_epoch = loss_epoch / (max(data_count, 1) * 1.0)
        if args.vis:
            _vis_append_point(vis, 'loss_epoch', epoch, loss_avg_epoch, 'epoch', 'loss')
            _vis_append_point(vis, 'lr_epoch', epoch, np.array(scheduler.get_lr()), 'epoch', 'lr')

            # training accuracy over the predictions gathered this epoch
            score, _ = scores(train_gts, train_preds, n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Overall Acc : \t':
                    _vis_append_point(vis, 'acc_epoch', epoch, v, 'epoch', 'accuracy')
            # clear for the next epoch's training metrics
            train_gts, train_preds = [], []

        if args.save_model and epoch%args.save_epoch==0:
            torch.save(model.state_dict(), '{}_{}_class_{}_{}.pt'.format(args.structure, args.dataset, args.n_classes, epoch))
コード例 #3
0
def validate(args):
    """Evaluate a segmentation model on a dataset and print its scores.

    The model is either loaded whole from ``args.validate_model`` or built by
    name from ``args.structure`` and optionally initialised from
    ``args.validate_model_state_dict``.  Arg-max predictions for every image
    are compared with the ground truth through ``scores`` and the overall and
    per-class results are printed.

    Args:
        args: parsed option namespace.  Attributes read here: ``vis``,
            ``dataset_path``, ``dataset``, ``dataset_type``, ``validate_model``,
            ``structure``, ``init_vgg16``, ``validate_model_state_dict`` and
            ``cuda``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.structure`` is not one of
            the supported names.
    """
    if args.vis:
        # The handle is not used further down; the connection is still opened
        # so that enabling ``vis`` behaves as before.
        vis = visdom.Visdom()

    # Fall back to ~/Data/CamVid only when no path was given; otherwise expand
    # '~' in the user-supplied path.
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = os.path.expanduser(args.dataset_path)

    if args.dataset == 'CamVid':
        dst = camvidLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'CityScapes':
        # NOTE(review): no split is passed here, so args.dataset_type is
        # ignored for CityScapes - confirm this is intended.
        dst = cityscapesLoader(local_path, is_transform=True)
    else:
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)

    if args.validate_model != '':
        model = torch.load(args.validate_model)
    else:
        n_classes = dst.n_classes
        pretrained = args.init_vgg16
        # Structure name -> zero-argument factory.  Lambdas keep construction
        # lazy, so only the selected network is looked up and built.
        builders = {
            'fcn32s': lambda: fcn(module_type='32s', n_classes=n_classes, pretrained=pretrained),
            'fcn16s': lambda: fcn(module_type='16s', n_classes=n_classes, pretrained=pretrained),
            'fcn8s': lambda: fcn(module_type='8s', n_classes=n_classes, pretrained=pretrained),
            'fcn_resnet18_32s': lambda: fcn_resnet18(module_type='32s', n_classes=n_classes, pretrained=pretrained),
            'fcn_resnet18_16s': lambda: fcn_resnet18(module_type='16s', n_classes=n_classes, pretrained=pretrained),
            'fcn_resnet18_8s': lambda: fcn_resnet18(module_type='8s', n_classes=n_classes, pretrained=pretrained),
            'fcn_resnet34_32s': lambda: fcn_resnet34(module_type='32s', n_classes=n_classes, pretrained=pretrained),
            'fcn_resnet34_16s': lambda: fcn_resnet34(module_type='16s', n_classes=n_classes, pretrained=pretrained),
            'fcn_resnet34_8s': lambda: fcn_resnet34(module_type='8s', n_classes=n_classes, pretrained=pretrained),
            'fcn_MobileNet_32s': lambda: fcn_MobileNet(module_type='32s', n_classes=n_classes, pretrained=pretrained),
            'fcn_MobileNet_16s': lambda: fcn_MobileNet(module_type='16s', n_classes=n_classes, pretrained=pretrained),
            'fcn_MobileNet_8s': lambda: fcn_MobileNet(module_type='8s', n_classes=n_classes, pretrained=pretrained),
            'ResNetDUC': lambda: ResNetDUC(n_classes=n_classes, pretrained=pretrained),
            'ResNetDUCHDC': lambda: ResNetDUCHDC(n_classes=n_classes, pretrained=pretrained),
            'segnet': lambda: segnet(n_classes=n_classes, pretrained=pretrained),
            'segnet_vgg19': lambda: segnet_vgg19(n_classes=n_classes, pretrained=pretrained),
            'segnet_unet': lambda: segnet_unet(n_classes=n_classes, pretrained=pretrained),
            'segnet_alignres': lambda: segnet_alignres(n_classes=n_classes, pretrained=pretrained),
            'sqnet': lambda: sqnet(n_classes=n_classes, pretrained=pretrained),
            'segnet_squeeze': lambda: segnet_squeeze(n_classes=n_classes, pretrained=pretrained),
            'ENet': lambda: ENet(n_classes=n_classes),
            'ENetV2': lambda: ENetV2(n_classes=n_classes),
            'drn_d_22': lambda: DRNSeg(model_name='drn_d_22', n_classes=n_classes, pretrained=pretrained),
            'drn_a_50': lambda: DRNSeg(model_name='drn_a_50', n_classes=n_classes, pretrained=pretrained),
            'drn_a_18': lambda: DRNSeg(model_name='drn_a_18', n_classes=n_classes, pretrained=pretrained),
            'drn_e_22': lambda: DRNSeg(model_name='drn_e_22', n_classes=n_classes, pretrained=pretrained),
            'erfnet': lambda: erfnet(n_classes=n_classes),
            'fcdensenet103': lambda: fcdensenet103(n_classes=n_classes),
            'fcdensenet56': lambda: fcdensenet56(n_classes=n_classes),
            'fcdensenet_tiny': lambda: fcdensenet_tiny(n_classes=n_classes),
            'Res_Deeplab_101': lambda: Res_Deeplab_101(n_classes=n_classes),
            'Res_Deeplab_50': lambda: Res_Deeplab_50(n_classes=n_classes),
            'EDANet': lambda: EDANet(n_classes=n_classes),
            # this variant is always built without pretrained weights
            'drn_a_asymmetric_18': lambda: DRNSeg(model_name='drn_a_asymmetric_18', n_classes=n_classes, pretrained=False),
        }
        if args.structure not in builders:
            raise ValueError('unsupported structure: {!r}'.format(args.structure))
        model = builders[args.structure]()

        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(torch.load(args.validate_model_state_dict))
            except KeyError:
                print('missing key')
    if args.cuda:
        model.cuda()
    model.eval()

    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap the loader output as pytorch Variables.
        imgs = Variable(imgs)
        labels = Variable(labels)
        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()

        outputs = model(imgs)
        # max over axis 1 (the class axis of a batch*n_classes*height*width
        # output) returns (values, indices); the indices are the labels.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()

        gts.extend(gt)
        preds.extend(pred)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
コード例 #4
0
ファイル: validate.py プロジェクト: USTCPCS/semseg
def validate(args):
    """Evaluate a segmentation model on the CamVid 'val' split and print its scores.

    The model is either loaded whole from ``args.validate_model`` or built by
    name from ``args.structure`` and optionally initialised from
    ``args.validate_model_state_dict``.  Inference runs on the CPU.  When both
    ``args.vis`` and ``args.blend`` are set, every 50th image is blended 50/50
    with its colourised prediction and written to ``/tmp/label_blend.png``
    (each write replaces the previous one).

    Args:
        args: parsed option namespace.  Attributes read here: ``vis``,
            ``dataset_path``, ``validate_model``, ``structure``,
            ``validate_model_state_dict`` and ``blend``.

    Raises:
        ValueError: if ``args.structure`` is not one of the supported names.
    """
    if args.vis:
        # The handle is not used further down; the connection is still opened
        # so that enabling ``vis`` behaves as before.
        vis = visdom.Visdom()
    if args.dataset_path == '':
        local_path = os.path.join(os.path.expanduser('~'), 'Data/CamVid')
    else:
        local_path = args.dataset_path
    dst = camvidLoader(local_path, is_transform=True, split='val')
    valloader = torch.utils.data.DataLoader(dst, batch_size=1)

    if args.validate_model != '':
        model = torch.load(args.validate_model)
    else:
        n_classes = dst.n_classes
        # Structure name -> zero-argument factory.  Lambdas keep construction
        # lazy, so only the selected network is looked up and built.
        builders = {
            'fcn32s': lambda: fcn(module_type='32s', n_classes=n_classes),
            'fcn16s': lambda: fcn(module_type='16s', n_classes=n_classes),
            'fcn8s': lambda: fcn(module_type='8s', n_classes=n_classes),
            'ResNetDUC': lambda: ResNetDUC(n_classes=n_classes),
            'segnet': lambda: segnet(n_classes=n_classes),
            'ENet': lambda: ENet(n_classes=n_classes),
            'drn_d_22': lambda: DRNSeg(model_name='drn_d_22', n_classes=n_classes),
            'pspnet': lambda: pspnet(n_classes=n_classes, use_aux=False),
            'erfnet': lambda: erfnet(n_classes=n_classes),
        }
        if args.structure not in builders:
            raise ValueError('unsupported structure: {!r}'.format(args.structure))
        model = builders[args.structure]()

        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(torch.load(args.validate_model_state_dict))
            except KeyError:
                print('missing key')
    model.eval()

    # Per-channel offset added back to the image before blending.
    # NOTE(review): presumably the mean subtracted by the loader - confirm.
    channel_mean = np.array([104.00699, 116.66877, 122.67892])

    gts, preds = [], []
    for i, (imgs, labels) in enumerate(valloader):
        print(i)
        # Wrap the loader output as pytorch Variables.
        imgs = Variable(imgs)
        labels = Variable(labels)

        outputs = model(imgs)
        # max over axis 1 (the class axis of a batch*n_classes*height*width
        # output) returns (values, indices); the indices are the labels.
        pred = outputs.data.max(1)[1].numpy()
        gt = labels.data.numpy()

        if args.vis and i % 50 == 0 and args.blend:
            # Bring the network input back to an HWC uint8 image.
            img_hwc = imgs.data.numpy()[0].transpose(1, 2, 0) * 255.0
            img_hwc += channel_mean
            img_hwc = np.array(img_hwc, dtype=np.uint8)

            pred_label_color_hwc = np.array(dst.decode_segmap(pred[0]), dtype=np.uint8)

            label_blend = img_hwc * 0.5 + pred_label_color_hwc * 0.5
            label_blend = np.array(label_blend, dtype=np.uint8)
            misc.imsave('/tmp/label_blend.png', label_blend)

        gts.extend(gt)
        preds.extend(pred)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    for i in range(dst.n_classes):
        print(i, class_iou[i])
コード例 #5
0
ファイル: train_pred.py プロジェクト: qxdaaaaa/semseg-1
def train(args):
    """Train a segmentation-prediction network with periodic validation.

    Builds the train/val datasets selected by ``args.dataset``, creates the
    network named by ``args.structure`` (or resumes one from disk), and runs
    the epoch loop with gradient accumulation. Every ``args.val_interval``
    epochs the current weights are copied into a second network built for the
    validation input shape; the state dict is saved whenever the mean IoU
    improves. When ``args.vis`` is set, loss, learning rate, mIoU and training
    accuracy are plotted to visdom.

    Args:
        args: parsed command-line namespace. Fields read here: ``vis``,
            ``dataset``, ``dataset_path``, ``cuda``, ``batch_size``,
            ``resume_model``, ``resume_model_state_dict``, ``structure``,
            ``n_classes``, ``init_vgg16``, ``solver``, ``lr``, ``lr_policy``,
            ``training_epoch``, ``grad_acc_steps``, ``val_interval``,
            ``save_model`` and ``save_epoch``.

    Raises:
        ValueError: if a network has to be constructed for a dataset that
            does not define an input channel count.
    """
    def type_callback(event):
        """Visdom text-window handler: pressing ``s`` dumps the loss curve."""
        # Only the single key 's' triggers an action; every other event is
        # ignored.
        if event['event_type'] != 'KeyPress' or event['key'] != 's':
            return
        import json
        win = 'loss_iteration'
        win_data = vis.get_window_data(win)
        win_data_dict = json.loads(win_data)
        win_data_content_dict = win_data_dict['content']
        win_data_x = np.array(win_data_content_dict['data'][0]['x'])
        win_data_y = np.array(win_data_content_dict['data'][0]['y'])

        win_data_save_file = '/tmp/loss_iteration_{}.txt'.format(init_time)
        # Text mode: the lines written below are str, which a file opened in
        # binary mode rejects on Python 3.
        with open(win_data_save_file, 'w') as f:
            for item_x, item_y in zip(win_data_x, win_data_y):
                f.write("{} {}\n".format(item_x, item_y))
        done_time = str(int(time.time()))
        vis.text(vis_text_usage + 'done at {}'.format(done_time),
                 win=callback_text_usage_window)

    def plot_point(win, x, y, xlabel, ylabel):
        """Append one sample to visdom line window ``win``, drawing it if needed."""
        x_arr = np.ones(1) * x
        y_arr = np.expand_dims(y, axis=0)
        win_res = vis.line(X=x_arr, Y=y_arr, win=win, update='append')
        if win_res != win:
            # The append call did not return this window id, so draw the
            # window with its title and axis labels.
            vis.line(X=x_arr,
                     Y=y_arr,
                     win=win,
                     opts=dict(title=win, xlabel=xlabel, ylabel=ylabel))

    init_time = str(int(time.time()))
    if args.vis:
        # start visdom and close all windows
        vis = visdom.Visdom()
        vis.close()

        vis_text_usage = 'Operating in the text window<br>Press s to save data<br>'
        callback_text_usage_window = vis.text(vis_text_usage)
        vis.register_event_handler(type_callback, callback_text_usage_window)

    class_weight = None
    local_path = os.path.expanduser(args.dataset_path)
    train_dst = None
    val_dst = None
    # Stays None for datasets that do not define a channel count, so the
    # problem is reported explicitly instead of as an unbound-variable error.
    input_channel = None
    if args.dataset == 'SegmPred':
        input_channel = 19
        train_dst = segmpredLoader(local_path,
                                   is_transform=True,
                                   split='train')
        val_dst = segmpredLoader(local_path, is_transform=True, split='val')
    elif args.dataset == 'MovingMNIST':
        # class_weight = [0.1, 0.5]
        # class_weight = torch.tensor(class_weight)
        train_dst = movingmnistLoader(local_path,
                                      is_transform=True,
                                      split='train')
        val_dst = movingmnistLoader(local_path, is_transform=True, split='val')
    elif args.dataset == 'FreeSpacePred':
        input_channel = 1
        train_dst = freespacepredLoader(local_path,
                                        is_transform=True,
                                        split='train')
        val_dst = freespacepredLoader(local_path,
                                      is_transform=True,
                                      split='test')
    else:
        print('{} dataset does not implement'.format(args.dataset))
        exit(0)

    if args.cuda:
        if class_weight is not None:
            class_weight = class_weight.cuda()
    print('class_weight:', class_weight)

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=1,
                                             shuffle=True)

    # A network is constructed when training starts from scratch and whenever
    # validation is enabled (validation uses its own network instance).
    need_val_model = args.val_interval > 0
    if input_channel is None and (args.resume_model == '' or need_val_model):
        raise ValueError(
            'input_channel is not defined for dataset {}'.format(
                args.dataset))

    start_epoch = 0
    best_mIoU = 0
    if args.resume_model != '':
        # NOTE: torch.load unpickles arbitrary objects; only load trusted
        # checkpoint files.
        model = torch.load(args.resume_model)
        # The epoch number is the text between the last '_' and the last '.'
        # of the checkpoint file name.
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1 +
                                            1:start_epoch_id2])
    else:
        # NOTE: eval() resolves the architecture from a command-line string;
        # do not feed it untrusted input.
        model = eval(args.structure)(n_classes=args.n_classes,
                                     pretrained=args.init_vgg16,
                                     input_shape=train_dst.input_shape,
                                     input_channel=input_channel)

        if args.resume_model_state_dict != '':
            try:
                # Recover the best mIoU and the epoch from the file name,
                # which follows the '..._miou_<v>_class_<n>_<epoch>.pt'
                # pattern used when saving below.
                miou_model_name_str = '_miou_'
                class_model_name_str = '_class_'
                miou_id1 = args.resume_model_state_dict.find(
                    miou_model_name_str) + len(miou_model_name_str)
                miou_id2 = args.resume_model_state_dict.find(
                    class_model_name_str)
                best_mIoU = float(
                    args.resume_model_state_dict[miou_id1:miou_id2])

                start_epoch_id1 = args.resume_model_state_dict.rfind('_')
                start_epoch_id2 = args.resume_model_state_dict.rfind('.')
                start_epoch = int(
                    args.resume_model_state_dict[start_epoch_id1 +
                                                 1:start_epoch_id2])
                pretrained_dict = torch.load(args.resume_model_state_dict,
                                             map_location='cpu')
                model.load_state_dict(pretrained_dict)
            except KeyError:
                print('missing resume_model_state_dict or wrong type')

    # Separate network built with the validation input shape. It is created
    # on both the fresh and the resumed path so validation never hits an
    # undefined name.
    val_model = None
    if need_val_model:
        val_model = eval(args.structure)(n_classes=args.n_classes,
                                         pretrained=args.init_vgg16,
                                         input_shape=val_dst.input_shape,
                                         input_channel=input_channel)

    if args.cuda:
        model.cuda()
        if val_model is not None:
            val_model.cuda()
    print('start_epoch:', start_epoch)
    print('best_mIoU:', best_mIoU)

    if args.solver == 'SGD':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    elif args.solver == 'RMSprop':
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                               model.parameters()),
                                        lr=args.lr,
                                        momentum=0.99,
                                        weight_decay=5e-4)
    elif args.solver == 'Adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     weight_decay=5e-4)
    else:
        print('missing solver or not support')
        exit(0)

    if args.lr_policy == 'Constant':
        scheduler = ConstantLR(optimizer)
    elif args.lr_policy == 'Polynomial':
        scheduler = PolynomialLR(
            optimizer, max_iter=args.training_epoch,
            power=0.9)  # base lr=0.01 power=0.9 like PSPNet
    elif args.lr_policy == 'MultiStep':
        scheduler = MultiStepLR(optimizer,
                                milestones=[10, 50, 90],
                                gamma=0.1)
    else:
        # Same handling as an unknown solver, instead of failing later on an
        # undefined scheduler.
        print('missing lr_policy or not support')
        exit(0)

    # Number of full batches per epoch; partial batches are skipped below.
    data_count = int(len(train_dst) * 1.0 / args.batch_size)
    print('data_count:', data_count)
    train_gts, train_preds = [], []
    for epoch in range(start_epoch + 1, args.training_epoch, 1):
        loss_epoch = 0
        scheduler.step()

        # Start every epoch with clean gradients; within the epoch they are
        # accumulated over grad_acc_steps batches.
        optimizer.zero_grad()
        for i, (imgs, labels) in enumerate(train_loader):
            model.train()

            # The last batch may hold fewer than batch_size images
            # (e.g. 3 left with batch_size=4); stop the epoch there.
            imgs_batch = imgs.shape[0]
            if imgs_batch != args.batch_size:
                break

            imgs = Variable(imgs)
            labels = Variable(labels)

            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs = model(imgs)

            loss = cross_entropy2d(outputs, labels, weight=class_weight)

            # Back-propagate the loss scaled by the accumulation length so
            # the accumulated gradient is an average.
            loss_grad_acc_avg = loss * 1.0 / args.grad_acc_steps
            loss_grad_acc_avg.backward()

            loss_np = loss.cpu().data.numpy()
            loss_epoch += loss_np

            if (i + 1) % args.grad_acc_steps == 0:
                optimizer.step()
                # Gradients keep accumulating across backward() calls unless
                # they are cleared.
                optimizer.zero_grad()

            # ------------------train metrics------------------------------
            # argmax over dim 1 of the network output is the predicted label
            train_pred = outputs.cpu().data.max(1)[1].numpy()
            train_gt = labels.cpu().data.numpy()

            for train_gt_, train_pred_ in zip(train_gt, train_pred):
                train_gts.append(train_gt_)
                train_preds.append(train_pred_)
            # ------------------train metrics------------------------------

            if args.vis and i % 50 == 0:
                # Show the first sample of the batch: ground truth and
                # prediction as colour-coded label maps.
                label_color = train_dst.decode_segmap(train_gt[0]).transpose(
                    2, 0, 1)
                pred_label_color = train_dst.decode_segmap(
                    train_pred[0]).transpose(2, 0, 1)
                vis.image(label_color,
                          win='label_color',
                          opts=dict(title='Gt', caption='Ground Truth'))
                vis.image(pred_label_color,
                          win='pred_label_color',
                          opts=dict(title='Pred', caption='Prediction'))

            # per-iteration loss curve
            if args.vis:
                plot_point('loss_iteration',
                           i + data_count * (epoch - 1) + 1, loss_np,
                           'iteration', 'loss')

        # evaluate on the validation set and keep the best weights
        if args.val_interval > 0 and epoch % args.val_interval == 0:
            print('----starting val----')
            val_model.load_state_dict(model.state_dict())
            val_model.eval()

            val_gts, val_preds = [], []
            for val_i, (val_imgs, val_labels) in enumerate(val_loader):
                val_imgs = Variable(val_imgs, volatile=True)
                val_labels = Variable(val_labels, volatile=True)

                if args.cuda:
                    val_imgs = val_imgs.cuda()
                    val_labels = val_labels.cuda()

                val_outputs = val_model(val_imgs)
                val_pred = val_outputs.cpu().data.max(1)[1].numpy()
                val_gt = val_labels.cpu().data.numpy()
                for val_gt_, val_pred_ in zip(val_gt, val_pred):
                    val_gts.append(val_gt_)
                    val_preds.append(val_pred_)

            score, class_iou = scores(val_gts,
                                      val_preds,
                                      n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Mean IoU : \t':
                    v_iou = v
                    if v > best_mIoU:
                        best_mIoU = v_iou
                        torch.save(
                            model.state_dict(),
                            '{}_{}_miou_{}_class_{}_{}.pt'.format(
                                args.structure, args.dataset, best_mIoU,
                                args.n_classes, epoch))
                    # mIoU of this validation round
                    if args.vis:
                        plot_point('mIoU_epoch', epoch * args.val_interval,
                                   v_iou, 'epoch', 'mIoU')

            for class_i in range(args.n_classes):
                print(class_i, class_iou[class_i])
            print('----ending   val----')

        # Average loss of the epoch; guard against a dataset smaller than
        # one batch, where no iteration ran at all.
        if data_count > 0:
            loss_avg_epoch = loss_epoch / (data_count * 1.0)
        else:
            loss_avg_epoch = 0.0
        if args.vis:
            plot_point('loss_epoch', epoch, loss_avg_epoch, 'epoch', 'loss')
            plot_point('lr_epoch', epoch, np.array(scheduler.get_lr()),
                       'epoch', 'lr')

        # ------------------train metrics----------------------------------
        if args.vis:
            score, class_iou = scores(train_gts,
                                      train_preds,
                                      n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Overall Acc : \t':
                    # overall training accuracy of this epoch
                    plot_point('acc_epoch', epoch, v, 'epoch', 'accuracy')
            # clear for new training metrics
            train_gts, train_preds = [], []
        # ------------------train metrics----------------------------------

        if args.save_model and epoch % args.save_epoch == 0:
            torch.save(
                model.state_dict(),
                '{}_{}_class_{}_{}.pt'.format(args.structure, args.dataset,
                                              args.n_classes, epoch))
コード例 #6
0
ファイル: validate_pred.py プロジェクト: qxdaaaaa/semseg-1
def validate(args):
    """Evaluate a prediction-segmentation model on one dataset split.

    Loads the split ``args.dataset_type`` of ``args.dataset``, restores a
    network (a whole pickled model, or an architecture named by
    ``args.structure`` plus an optional state dict), runs it sample by
    sample, optionally writes colour-coded ground-truth/prediction PNGs under
    ``/tmp/<start time>/``, and prints the per-sample loss ranking, the
    aggregate scores and the per-class IoU.

    Args:
        args: parsed command-line namespace. Fields read here: ``vis``,
            ``dataset``, ``dataset_path``, ``dataset_type``,
            ``validate_model``, ``validate_model_state_dict``, ``structure``,
            ``n_classes``, ``init_vgg16``, ``cuda`` and ``save_result``.
    """
    init_time = str(int(time.time()))
    if args.vis:
        vis = visdom.Visdom()

    local_path = os.path.expanduser(args.dataset_path)
    # Stays None for datasets without a known channel count so the failure
    # can be reported precisely when a network has to be built.
    input_channel = None
    if args.dataset == 'SegmPred':
        # NOTE(review): 19 mirrors the SegmPred setting of the companion
        # training script -- confirm it matches the checkpoints in use.
        input_channel = 19
        dst = segmpredLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'MovingMNIST':
        dst = movingmnistLoader(local_path, is_transform=True, split=args.dataset_type)
    elif args.dataset == 'FreeSpacePred':
        input_channel = 1
        dst = freespacepredLoader(local_path, is_transform=True, split=args.dataset_type)
    else:
        # Without a dataset nothing below can run; stop with a clear message
        # instead of failing on an undefined name.
        print('{} dataset does not implement'.format(args.dataset))
        exit(0)
    val_loader = torch.utils.data.DataLoader(dst, batch_size=1, shuffle=False)

    if args.validate_model != '':
        # NOTE: torch.load unpickles arbitrary objects; only load trusted
        # checkpoint files.
        model = torch.load(args.validate_model)
    else:
        if input_channel is None:
            print('{} dataset has no input_channel configured'.format(args.dataset))
            exit(0)
        try:
            # NOTE: eval() resolves the architecture from a command-line
            # string; do not feed it untrusted input.
            model = eval(args.structure)(n_classes=args.n_classes, pretrained=args.init_vgg16, input_shape=dst.input_shape, input_channel=input_channel)
        except Exception as err:
            # Report the underlying cause as well; a bare except would also
            # swallow KeyboardInterrupt.
            print('missing structure or not support:', err)
            exit(0)
        if args.validate_model_state_dict != '':
            try:
                model.load_state_dict(torch.load(args.validate_model_state_dict, map_location='cpu'))
            except KeyError:
                print('missing key')
    if args.cuda:
        model.cuda()
    # some models perform differently depending on the mode they run in
    model.eval()

    result_dir = '/tmp/' + init_time
    if args.save_result and not os.path.exists(result_dir):
        os.mkdir(result_dir)

    gts, preds, errors, imgs_name = [], [], [], []
    for i, (imgs, labels) in enumerate(val_loader):
        print(i)
        # Name each sample after the last two components of its path,
        # joined with '_' so it is usable as a file name.
        img_path = dst.files[args.dataset_type][i]
        img_name = img_path[img_path.rfind('/', 0, img_path.rfind('/'))+1:]
        img_name = img_name.replace('/', '_')
        imgs_name.append(img_name)

        # wrap the loaded batch as pytorch Variables
        imgs = Variable(imgs, volatile=True)
        labels = Variable(labels, volatile=True)

        if args.cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()

        outputs = model(imgs)
        loss = cross_entropy2d(outputs, labels)
        loss_np = loss.cpu().data.numpy()
        errors.append(float(loss_np))

        # max over dim 1 returns (values, indices); the indices are taken as
        # the predicted label map.
        pred = outputs.cpu().data.max(1)[1].numpy()
        gt = labels.cpu().data.numpy()

        if args.save_result:
            label_color = dst.decode_segmap(gt[0]).transpose(2, 0, 1)
            pred_label_color = dst.decode_segmap(pred[0]).transpose(2, 0, 1)

            # back to height-width-channel and BGR order for cv2.imwrite
            label_color_cv2 = label_color.transpose(1, 2, 0)
            label_color_cv2 = cv2.cvtColor(label_color_cv2, cv2.COLOR_RGB2BGR)
            cv2.imwrite(result_dir+'/gt_{}.png'.format(img_name), label_color_cv2)

            pred_label_color_cv2 = pred_label_color.transpose(1, 2, 0)
            pred_label_color_cv2 = cv2.cvtColor(pred_label_color_cv2, cv2.COLOR_RGB2BGR)
            cv2.imwrite(result_dir+'/pred_{}.png'.format(img_name), pred_label_color_cv2)

        for gt_, pred_ in zip(gt, pred):
            gts.append(gt_)
            preds.append(pred_)

    # errors_indices[k] is the sample index with the k-th smallest loss.
    errors_indices = np.argsort(errors).tolist()
    print('errors_indices:', errors_indices)
    # List up to ten samples in that order (fewer if the split is smaller).
    # NOTE(review): this is lowest loss first; iterate the list reversed if
    # the worst samples are what should be reported.
    for top_index in errors_indices[:10]:
        img_name_top = imgs_name[top_index]
        print('img_name_top:', img_name_top)

    score, class_iou = scores(gts, preds, n_class=dst.n_classes)
    for k, v in score.items():
        print(k, v)

    class_iou_list = []
    for i in range(dst.n_classes):
        class_iou_list.append(round(class_iou[i], 2))
    # list() so the class ids are printed rather than a range object
    print('classes:', list(range(dst.n_classes)))
    print('class_iou_list:', class_iou_list)
コード例 #7
0
def train(args):
    now = datetime.datetime.now()
    now_str = '{}-{}-{} {}:{}:{}'.format(now.year, now.month, now.day,
                                         now.hour, now.minute, now.second)
    # print('now:', now)
    # print('now_str:', now_str)
    if args.vis:
        # start visdom and close all window
        vis = visdom.Visdom(env=now_str)
        vis.close()

    class_weight = None
    local_path = os.path.expanduser(args.dataset_path)
    train_dst = None
    val_dst = None
    if args.dataset == 'CamVid':
        train_dst = camvidLoader(local_path,
                                 is_transform=True,
                                 is_augment=args.data_augment,
                                 split='train')
        val_dst = camvidLoader(local_path,
                               is_transform=True,
                               is_augment=False,
                               split='val')

        trainannot_image_dir = os.path.expanduser(
            os.path.join(local_path, "trainannot"))
        trainannot_image_files = [
            os.path.join(trainannot_image_dir, file)
            for file in os.listdir(trainannot_image_dir)
            if file.endswith('.png')
        ]
        if args.class_weighting == 'MFB':
            class_weight = median_frequency_balancing(trainannot_image_files,
                                                      num_classes=12)
            class_weight = torch.tensor(class_weight)
        elif args.class_weighting == 'ENET':
            class_weight = ENet_weighing(trainannot_image_files,
                                         num_classes=12)
            class_weight = torch.tensor(class_weight)
    elif args.dataset == 'CityScapes':
        train_dst = cityscapesLoader(local_path,
                                     is_transform=True,
                                     split='train')
        val_dst = cityscapesLoader(local_path, is_transform=True, split='val')
    elif args.dataset == 'SegmPred':
        train_dst = segmpredLoader(local_path,
                                   is_transform=True,
                                   split='train')
        val_dst = segmpredLoader(local_path, is_transform=True, split='train')
    elif args.dataset == 'MovingMNIST':
        # class_weight = [0.1, 0.5]
        # class_weight = torch.tensor(class_weight)
        train_dst = movingmnistLoader(local_path,
                                      is_transform=True,
                                      split='train')
        val_dst = movingmnistLoader(local_path, is_transform=True, split='val')
    elif args.dataset == 'FreeSpace':
        train_dst = freespaceLoader(local_path,
                                    is_transform=True,
                                    split='train')
        val_dst = freespaceLoader(local_path, is_transform=True, split='val')
    else:
        print('{} dataset does not implement'.format(args.dataset))
        exit(0)

    if args.cuda:
        if class_weight is not None:
            class_weight = class_weight.cuda()
    print('class_weight:', class_weight)

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=1,
                                             shuffle=True)

    start_epoch = 0
    best_mIoU = 0
    if args.resume_model != '':
        model = torch.load(args.resume_model)
        start_epoch_id1 = args.resume_model.rfind('_')
        start_epoch_id2 = args.resume_model.rfind('.')
        start_epoch = int(args.resume_model[start_epoch_id1 +
                                            1:start_epoch_id2])
    else:
        # model = eval(args.structure)(n_classes=args.n_classes, pretrained=args.init_vgg16)
        try:
            model = eval(args.structure)(n_classes=args.n_classes,
                                         pretrained=args.init_vgg16)
        except:
            print('missing structure or not support')
            exit(0)

        # ---------------for testing SegmPred---------------
        if args.dataset == 'MovingMNIST':
            input_channel = 1 * 9
        elif args.dataset == 'SegmPred':
            input_channel = 19 * 4
        if args.structure == 'drnseg_a_18':
            model = drnseg_a_18(n_classes=args.n_classes,
                                pretrained=args.init_vgg16,
                                input_channel=input_channel)
        # ---------------for testing SegmPred---------------

        if args.resume_model_state_dict != '':
            try:
                # from model save format get useful information, such as miou, epoch
                miou_model_name_str = '_miou_'
                class_model_name_str = '_class_'
                miou_id1 = args.resume_model_state_dict.find(
                    miou_model_name_str) + len(miou_model_name_str)
                miou_id2 = args.resume_model_state_dict.find(
                    class_model_name_str)
                best_mIoU = float(
                    args.resume_model_state_dict[miou_id1:miou_id2])

                start_epoch_id1 = args.resume_model_state_dict.rfind('_')
                start_epoch_id2 = args.resume_model_state_dict.rfind('.')
                start_epoch = int(
                    args.resume_model_state_dict[start_epoch_id1 +
                                                 1:start_epoch_id2])
                pretrained_dict = torch.load(args.resume_model_state_dict,
                                             map_location='cpu')
                model.load_state_dict(pretrained_dict)
            except KeyError:
                print('missing resume_model_state_dict or wrong type')

    if args.cuda:
        model.cuda()
    print('start_epoch:', start_epoch)
    print('best_mIoU:', best_mIoU)

    if args.solver == 'SGD':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    elif args.solver == 'RMSprop':
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                               model.parameters()),
                                        lr=args.lr,
                                        momentum=0.99,
                                        weight_decay=5e-4)
    elif args.solver == 'Adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     weight_decay=5e-4)
    else:
        print('missing solver or not support')
        exit(0)
    # when observerd object dose not decrease scheduler will let the optimizer learing rate decrease
    # scheduler = ReduceLROnPlateau(optimizer, 'min', patience=100, min_lr=1e-10, verbose=True)
    if args.lr_policy == 'Constant':
        scheduler = ConstantLR(optimizer)
    elif args.lr_policy == 'Polynomial':
        scheduler = PolynomialLR(
            optimizer, max_iter=args.training_epoch,
            power=0.9)  # base lr=0.01 power=0.9 like PSPNet
    elif args.lr_policy == 'MultiStep':
        scheduler = MultiStepLR(
            optimizer, milestones=[10, 50, 90],
            gamma=0.1)  # base lr=0.01 power=0.9 like PSPNet

    # scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

    data_count = int(train_dst.__len__() * 1.0 / args.batch_size)
    print('data_count:', data_count)
    # iteration_step = 0
    train_gts, train_preds = [], []
    for epoch in range(start_epoch + 1, args.training_epoch, 1):
        loss_epoch = 0
        scheduler.step()

        optimizer.zero_grad(
        )  # when train next time zero all grad, just acc the grad when the epoch training
        for i, (imgs, labels) in enumerate(train_loader):
            # if i==1:
            #     break
            model.train()

            # 最后的几张图片可能不到batch_size的数量,比如batch_size=4,可能只剩3张
            imgs_batch = imgs.shape[0]
            if imgs_batch != args.batch_size:
                break
            # iteration_step += 1

            imgs = Variable(imgs)
            labels = Variable(labels)

            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            outputs = model(imgs)
            # print('imgs.size:', imgs.size())
            # print('labels.size:', labels.size())
            # print('outputs.size:', outputs.size())

            loss = cross_entropy2d(outputs, labels, weight=class_weight)

            # add grad backward the avg loss
            loss_grad_acc_avg = loss * 1.0 / args.grad_acc_steps
            loss_grad_acc_avg.backward()

            loss_np = loss.cpu().data.numpy()
            loss_epoch += loss_np

            if (i + 1) % args.grad_acc_steps == 0:
                optimizer.step()
                # 一次backward后如果不清零,梯度是累加的
                optimizer.zero_grad()

            # ------------------train metris-------------------------------
            train_pred = outputs.cpu().data.max(1)[1].numpy()
            train_gt = labels.cpu().data.numpy()

            for train_gt_, train_pred_ in zip(train_gt, train_pred):
                train_gts.append(train_gt_)
                train_preds.append(train_pred_)
            # ------------------train metris-------------------------------

            if args.vis and i % 50 == 0:
                pred_labels = outputs.cpu().data.max(1)[1].numpy()
                label_color = train_dst.decode_segmap(
                    labels.cpu().data.numpy()[0]).transpose(2, 0, 1)
                pred_label_color = train_dst.decode_segmap(
                    pred_labels[0]).transpose(2, 0, 1)
                win = 'label_color'
                vis.image(label_color,
                          win=win,
                          opts=dict(title='Gt', caption='Ground Truth'))
                win = 'pred_label_color'
                vis.image(pred_label_color,
                          win=win,
                          opts=dict(title='Pred', caption='Prediction'))

            # Plot the per-iteration loss curve
            if args.vis:
                win = 'loss_iteration'
                loss_np_expand = np.expand_dims(loss_np, axis=0)
                win_res = vis.line(X=np.ones(1) * (i + data_count *
                                                   (epoch - 1) + 1),
                                   Y=loss_np_expand,
                                   win=win,
                                   update='append')
                if win_res != win:
                    vis.line(X=np.ones(1) * (i + data_count * (epoch - 1) + 1),
                             Y=loss_np_expand,
                             win=win,
                             opts=dict(title=win,
                                       xlabel='iteration',
                                       ylabel='loss'))

        # val result on val dataset and pick best to save
        if args.val_interval > 0 and epoch % args.val_interval == 0:
            print('----starting val----')
            model.eval()

            val_gts, val_preds = [], []
            for val_i, (val_imgs, val_labels) in enumerate(val_loader):
                # print(val_i)
                val_imgs = Variable(val_imgs, volatile=True)
                val_labels = Variable(val_labels, volatile=True)

                if args.cuda:
                    val_imgs = val_imgs.cuda()
                    val_labels = val_labels.cuda()

                val_outputs = model(val_imgs)
                val_pred = val_outputs.cpu().data.max(1)[1].numpy()
                val_gt = val_labels.cpu().data.numpy()
                for val_gt_, val_pred_ in zip(val_gt, val_pred):
                    val_gts.append(val_gt_)
                    val_preds.append(val_pred_)

            score, class_iou = scores(val_gts,
                                      val_preds,
                                      n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Mean IoU : \t':
                    v_iou = v
                    if v > best_mIoU:
                        best_mIoU = v_iou
                        torch.save(
                            model.state_dict(),
                            '{}_{}_miou_{}_class_{}_{}.pt'.format(
                                args.structure, args.dataset, best_mIoU,
                                args.n_classes, epoch))
                    # Plot the mIoU of this validation epoch
                    if args.vis:
                        win = 'mIoU_epoch'
                        v_iou_expand = np.expand_dims(v_iou, axis=0)
                        win_res = vis.line(X=np.ones(1) * epoch *
                                           args.val_interval,
                                           Y=v_iou_expand,
                                           win=win,
                                           update='append')
                        if win_res != win:
                            vis.line(X=np.ones(1) * epoch * args.val_interval,
                                     Y=v_iou_expand,
                                     win=win,
                                     opts=dict(title=win,
                                               xlabel='epoch',
                                               ylabel='mIoU'))

            for class_i in range(args.n_classes):
                print(class_i, class_iou[class_i])
            print('----ending   val----')

        # Plot the average loss per epoch across epochs
        loss_avg_epoch = loss_epoch / (data_count * 1.0)
        # print(loss_avg_epoch)
        if args.vis:
            win = 'loss_epoch'
            loss_avg_epoch_expand = np.expand_dims(loss_avg_epoch, axis=0)
            win_res = vis.line(X=np.ones(1) * epoch,
                               Y=loss_avg_epoch_expand,
                               win=win,
                               update='append')
            if win_res != win:
                vis.line(X=np.ones(1) * epoch,
                         Y=loss_avg_epoch_expand,
                         win=win,
                         opts=dict(title=win, xlabel='epoch', ylabel='loss'))

        if args.vis:
            win = 'lr_epoch'
            lr_epoch = np.array(scheduler.get_lr())
            lr_epoch_expand = np.expand_dims(lr_epoch, axis=0)
            win_res = vis.line(X=np.ones(1) * epoch,
                               Y=lr_epoch_expand,
                               win=win,
                               update='append')
            if win_res != win:
                vis.line(X=np.ones(1) * epoch,
                         Y=lr_epoch_expand,
                         win=win,
                         opts=dict(title=win, xlabel='epoch', ylabel='lr'))

        # ------------------train metrics------------------------------
        if args.vis:
            score, class_iou = scores(train_gts,
                                      train_preds,
                                      n_class=args.n_classes)
            for k, v in score.items():
                print(k, v)
                if k == 'Overall Acc : \t':
                    # Plot the overall training accuracy of this epoch
                    overall_acc = v
                    if args.vis:
                        win = 'acc_epoch'
                        overall_acc_expand = np.expand_dims(overall_acc,
                                                            axis=0)
                        win_res = vis.line(X=np.ones(1) * epoch,
                                           Y=overall_acc_expand,
                                           win=win,
                                           update='append')
                        if win_res != win:
                            vis.line(X=np.ones(1) * epoch,
                                     Y=overall_acc_expand,
                                     win=win,
                                     opts=dict(title=win,
                                               xlabel='epoch',
                                               ylabel='accuracy'))
            # clear for new training metrics
            train_gts, train_preds = [], []
        # ------------------train metrics------------------------------

        if args.save_model and epoch % args.save_epoch == 0:
            torch.save(
                model.state_dict(),
                '{}_{}_class_{}_{}.pt'.format(args.structure, args.dataset,
                                              args.n_classes, epoch))