Example 1
def validate(val_loader, model, criterion, eval_score=None, print_freq=10):
    # miou part >>>
    confusion_labels = np.arange(0, 19)
    confusion_matrix = RunningConfusionMatrix(confusion_labels)
    # miou part <<<

    batch_time = AverageMeter()
    losses = AverageMeter()
    score = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        if isinstance(criterion, (torch.nn.L1Loss, torch.nn.MSELoss)):
            target = target.float()
        input = input.cuda()
        # `async` became a reserved word in Python 3.7; use non_blocking
        target = target.cuda(non_blocking=True)

        # compute output; Variable(..., volatile=True) is deprecated in
        # favor of torch.no_grad(), which disables autograd for evaluation
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)
        confusion_matrix.update_matrix(target, output)

        # measure accuracy and record loss
        # prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))  # loss.data[0] fails on 0-dim tensors
        if eval_score is not None:
            score.update(eval_score(output, target), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Score {score.val:.3f} ({score.avg:.3f})'.format(
                      i,
                      len(val_loader),
                      batch_time=batch_time,
                      loss=losses,
                      score=score),
                  flush=True)

    miou, top_1, top_5 = confusion_matrix.compute_current_mean_intersection_over_union()
    print(' * Score {score.avg:.3f}'.format(score=score))
    print(' * mIoU {miou:.3f}'.format(miou=miou))
    confusion_matrix.show_classes()

    return miou
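
These examples lean on helpers defined elsewhere in the repository. `AverageMeter` is the stock running-average tracker from the PyTorch ImageNet example; `RunningConfusionMatrix` accumulates a confusion matrix across batches and derives mIoU from it. Below is a minimal sketch of both. Only the method names (`update_matrix`, `compute_current_mean_intersection_over_union`, `show_classes`) and constructor arguments are taken from the calls above; the bodies are assumptions, and the sketch returns mIoU alone where the original evidently also returns top-1/top-5 statistics.

import numpy as np


class AverageMeter(object):
    """Running sum/average of a scalar (PyTorch ImageNet-example helper)."""

    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class RunningConfusionMatrix(object):
    """Hypothetical sketch: per-batch accumulation of an n x n confusion
    matrix, from which per-class IoU and mIoU are derived.
    `reduce` is accepted for signature compatibility but unused here."""

    def __init__(self, labels, ignore_label=255, reduce=False):
        self.labels = labels
        self.ignore_label = ignore_label
        self.n = len(labels)
        self.mat = np.zeros((self.n, self.n), dtype=np.int64)

    def update_matrix(self, target, output):
        # output: (N, C, H, W) class scores; target: (N, H, W) indices
        pred = output.argmax(1).cpu().numpy().ravel()
        gt = target.cpu().numpy().ravel()
        keep = (gt != self.ignore_label) & (gt >= 0) & (gt < self.n)
        self.mat += np.bincount(self.n * gt[keep] + pred[keep],
                                minlength=self.n ** 2).reshape(self.n, self.n)

    def _iou(self):
        tp = np.diag(self.mat)
        union = self.mat.sum(1) + self.mat.sum(0) - tp
        return tp / np.maximum(union, 1)

    def compute_current_mean_intersection_over_union(self):
        return np.nanmean(self._iou())  # original also returns top-1/top-5

    def show_classes(self):
        for label, iou in zip(self.labels, self._iou()):
            print('class {}: IoU {:.3f}'.format(label, iou))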
Example 2
def test(eval_data_loader,
         model,
         num_classes,
         output_dir='pred',
         has_gt=True,
         save_vis=False):
    model.eval()
    confusion_labels = np.arange(0, 19)
    confusion_matrix = RunningConfusionMatrix(confusion_labels)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    hist = np.zeros((num_classes, num_classes))
    for i, (image, label, name, size) in enumerate(eval_data_loader):
        data_time.update(time.time() - end)
        # Variable(..., volatile=True) is deprecated; torch.no_grad()
        # is the modern way to disable autograd during evaluation
        with torch.no_grad():
            final = model(image)[0]
        _, pred = torch.max(final, 1)
        pred = pred.cpu().numpy()
        batch_time.update(time.time() - end)
        prob = torch.exp(final)
        if save_vis:
            save_output_images(pred, name, output_dir, size)
            if prob.size(1) == 2:
                save_prob_images(prob, name, output_dir + '_prob', size)
            else:
                save_colorful_images(pred, name, output_dir + '_color',
                                     CITYSCAPE_PALLETE, size)
        if has_gt:
            # confusion_matrix.update_matrix(label, final)
            label = label.numpy()
            hist += fast_hist(pred.flatten(), label.flatten(), num_classes)
            # print('===> mAP {mAP:.3f}'.format(
            #     mAP=round(np.nanmean(per_class_iu(hist)) * 100, 2)))
        end = time.time()
        print('Eval: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
                  i,
                  len(eval_data_loader),
                  batch_time=batch_time,
                  data_time=data_time))
    ious = per_class_iu(hist) * 100

    if has_gt:  # val
        return round(np.nanmean(ious), 2)
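
`fast_hist` and `per_class_iu` in Example 2 are the usual segmentation-scoring helpers (they appear near-verbatim in several codebases, e.g. the drn and DLA repositories); if they are missing on your side, the standard definitions are:

import numpy as np


def fast_hist(pred, label, n):
    # n x n joint histogram of (label, pred) pairs, i.e. a confusion
    # matrix; pixels whose label falls outside [0, n) are skipped.
    k = (label >= 0) & (label < n)
    return np.bincount(n * label[k].astype(int) + pred[k],
                       minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    # Per-class IoU: true positives (diagonal) over the union
    # (row sum + column sum - diagonal).
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))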
Example 3
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    # print(dla_up.__dict__.get(args.arch))
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)

    single_model = convert_model(single_model)

    model = torch.nn.DataParallel(single_model).cuda()
    print('model created')
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        # nn.NLLLoss2d is deprecated; nn.NLLLoss accepts 2D targets directly
        criterion = nn.NLLLoss(ignore_index=-1, weight=weight)
    else:
        criterion = nn.NLLLoss(ignore_index=-1)

    criterion.cuda()

    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  #TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.append(transforms.RandomHorizontalFlip())  #TODO

    t_val = [transforms.RandomCrop(crop_size)]

    dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/image_02/'
    dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/' + args.target + '/'
    my_train = BasicDataset(dir_img,
                            dir_mask,
                            transforms.Compose(t),
                            is_train=True)

    val_dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/image_02/'
    val_dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/' + args.target + '/'
    my_val = BasicDataset(val_dir_img,
                          val_dir_mask,
                          transforms.Compose(t_val),
                          is_train=True)

    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        my_val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)  #TODO  batch_size
    print("loader created")
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = None  #TODO

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    confusion_labels = np.arange(0, 5)
    val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                  ignore_label=-1)

    if args.evaluate:
        confusion_labels = np.arange(0, 2)
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1,
                                                      reduce=True)
        validate(val_loader,
                 model,
                 criterion,
                 confusion_matrix=val_confusion_matrix)
        return
    writer = SummaryWriter(comment=args.log)

    # TODO test val
    # print("test val")
    # prec1 = validate(val_loader, model, criterion, confusion_matrix=val_confusion_matrix)

    for epoch in range(start_epoch, args.epochs):
        train_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                        ignore_label=-1)
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch

        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              lr_scheduler,
              confusion_matrix=train_confusion_matrix,
              writer=writer)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)

        # evaluate on validation set
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1)
        prec1, loss_val = validate(val_loader,
                                   model,
                                   criterion,
                                   confusion_matrix=val_confusion_matrix)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('mIoU/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)

        if (epoch + 1) % args.save_freq == 0:
            history_path = os.path.join(
                checkpoint_dir, 'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
            shutil.copyfile(checkpoint_path, history_path)

    writer.close()
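
Example 3 additionally assumes `save_checkpoint` and `adjust_learning_rate` (note that `lr_scheduler` is left as None, so all scheduling happens in the latter). `save_checkpoint` is the standard ImageNet-example helper; the decay policy below is a guess at a typical step schedule, and `args.step` is a hypothetical parameter:

import os
import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Save every epoch; additionally copy the best model aside.
    torch.save(state, filename)
    if is_best:
        best_path = os.path.join(os.path.dirname(filename),
                                 'model_best.pth.tar')
        shutil.copyfile(filename, best_path)


def adjust_learning_rate(args, optimizer, epoch):
    # Assumed step decay: scale the base lr by 0.1 every args.step epochs
    # (args.step is hypothetical; the real schedule may differ).
    lr = args.lr * (0.1 ** (epoch // args.step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr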