Пример #1
0
def main():
    """Run the BDRAR network on a fixed list of demo images.

    Weights are always loaded from ``model.pth`` (mapped onto ``cuda:0``).
    Each image in the list is transformed, passed through the network, and
    the output is converted back to a PIL image. That image is resized to
    the input's original size and written next to the input as
    ``out_<filename>``.
    """
    demo_images = ["shadow1.jpg", "shadow2.jpg", "shadow3.jpg"]
    total = len(demo_images)

    model = BDRAR().cuda()
    model.load_state_dict(torch.load('model.pth', map_location='cuda:0'))
    model.eval()

    with torch.no_grad():
        for position, filename in enumerate(demo_images, start=1):
            print('predicting for %s: %d / %d' % (filename, position, total))
            source = Image.open(filename)
            width, height = source.size
            batch = Variable(img_transform(source).unsqueeze(0)).cuda()
            output = model(batch)
            # Convert the single output back to PIL and restore the
            # original (height, width) before saving.
            resize = transforms.Resize((height, width))
            mask = np.array(resize(to_pil(output.data.squeeze(0).cpu())))
            Image.fromarray(mask).save("out_" + filename)
Пример #2
0
def main():
    """Evaluate a BDRAR snapshot on every dataset in ``to_test``.

    Each entry of ``to_test`` is a ``(name, root)`` pair. Every ``*.jpg``
    under ``root/images`` is paired, by sorted filename position, with a
    ``*.png`` under ``root/masks``. Each network prediction is resized to
    the input size and refined with ``crf_refine``. Both prediction and mask
    are binarised with ``> 0`` and scored with ``compute_stats``.

    True positives, false positives and false negatives are accumulated
    over all datasets. They are reported as precision, recall and F1; a
    ratio whose denominator is zero is reported as 0.

    Raises:
        AssertionError: if a dataset has a different number of images and
            masks.
    """
    net = BDRAR().cuda()

    if len(args['snapshot']) > 0:
        print('load snapshot \'%s\' for testing' % args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))

    test_tp, test_fp, test_fn = 0., 0., 0.
    net.eval()
    with torch.no_grad():
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for name, root in to_test.items():
            img_list = sorted(
                img_name
                for img_name in os.listdir(os.path.join(root, 'images'))
                if img_name.endswith('.jpg'))
            gt_list = sorted(
                img_name
                for img_name in os.listdir(os.path.join(root, 'masks'))
                if img_name.endswith('.png'))
            # Images and masks are paired purely by sorted position, so the
            # counts must agree. %-formatting makes the counts actually
            # appear in the message (str.format ignores %d placeholders).
            assert len(img_list) == len(gt_list), "%d != %d" % (
                len(img_list), len(gt_list))

            # The directory depends only on the dataset, so create it once
            # per dataset rather than once per image.
            check_mkdir(
                os.path.join(
                    ckpt_path, exp_name, '(%s) %s_prediction_%s' %
                    (exp_name, name, args['snapshot'])))

            for idx, (img_name, gt_name) in enumerate(zip(img_list, gt_list)):
                print('predicting for %s: %d / %d' %
                      (name, idx + 1, len(img_list)))
                img = Image.open(os.path.join(root, 'images', img_name))
                gt = Image.open(os.path.join(root, 'masks', gt_name))
                w, h = img.size
                img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
                res = net(img_var)
                # Resize the network output back to the input's (h, w).
                prediction = np.array(
                    transforms.Resize(
                        (h, w))(to_pil(res.data.squeeze(0).cpu())))
                prediction = crf_refine(np.array(img.convert('RGB')),
                                        prediction)

                # Any non-zero pixel counts as positive in both maps.
                # NOTE(review): assumes the mask decodes to the same 2-D
                # shape as the prediction (single-channel) -- confirm.
                pred_binary = (np.asarray(prediction) > 0).astype(int)
                gt_binary = (np.array(gt) > 0).astype(int)
                tp, fp, fn = compute_stats(pred_binary, gt_binary)
                test_tp += tp
                test_fp += fp
                test_fn += fn

    # Guard each ratio: with no predicted positives, no ground-truth
    # positives, or zero precision and recall, plain division raises
    # ZeroDivisionError. Such cases are reported as 0.
    pred_pos = test_tp + test_fp
    true_pos = test_tp + test_fn
    test_prec = test_tp / pred_pos if pred_pos > 0 else 0.
    test_rec = test_tp / true_pos if true_pos > 0 else 0.
    prec_rec = test_prec + test_rec
    test_f1 = 2 * test_prec * test_rec / prec_rec if prec_rec > 0 else 0.
    print('Testing precision: %.2f, recall: %.2f, f1: %.2f' %
          (test_prec, test_rec, test_f1))
Пример #3
0
def main():
    """Run a BDRAR snapshot over every dataset in ``to_test`` and save masks.

    Each entry of ``to_test`` is a ``(name, root)`` pair. Every ``*.jpg`` in
    ``root/images`` is processed in sorted filename order:

    - pass it through the network;
    - resize the output back to the input size;
    - refine it with ``crf_refine``;
    - write it as ``<stem>.png`` into
      ``ckpt_path/exp_name/(exp_name) <name>_prediction_<snapshot>``.

    When ``args['snapshot']`` is non-empty, weights are first loaded from
    ``ckpt_path/exp_name/<snapshot>.pth``.
    """
    net = BDRAR().cuda()

    if len(args['snapshot']) > 0:
        print('load snapshot \'%s\' for testing' % args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))

    net.eval()
    with torch.no_grad():
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for name, root in to_test.items():
            # Sorted so processing/progress order is deterministic
            # (os.listdir order is arbitrary).
            img_list = sorted(
                img_name
                for img_name in os.listdir(os.path.join(root, 'images'))
                if img_name.endswith('.jpg'))

            # The output directory depends only on the dataset: build the
            # path and create it once instead of twice per image.
            pred_dir = os.path.join(
                ckpt_path, exp_name,
                '(%s) %s_prediction_%s' % (exp_name, name, args['snapshot']))
            check_mkdir(pred_dir)

            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' %
                      (name, idx + 1, len(img_list)))
                img = Image.open(os.path.join(root, 'images', img_name))
                w, h = img.size
                img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
                res = net(img_var)
                # Resize the network output back to the input's (h, w).
                prediction = np.array(
                    transforms.Resize(
                        (h, w))(to_pil(res.data.squeeze(0).cpu())))
                prediction = crf_refine(np.array(img.convert('RGB')),
                                        prediction)

                Image.fromarray(prediction).save(
                    os.path.join(pred_dir,
                                 os.path.splitext(img_name)[0] + ".png"))
Пример #4
0
def main():
    """Build BDRAR and its SGD optimizer, optionally resume, then train.

    Parameters are split into two SGD groups:

    - names ending in ``bias`` use ``2 * args['lr']`` and set no per-group
      ``weight_decay``;
    - all others use ``args['lr']`` and ``args['weight_decay']``.

    When ``args['snapshot']`` is non-empty, model and optimizer state are
    restored from ``ckpt_path/exp_name/<snapshot>.pth`` and
    ``<snapshot>_optim.pth``. The per-group learning rates are then reset to
    the values derived from the current ``args``. Finally the checkpoint
    directories are created, ``args`` is written to ``log_path``, and
    ``train(net, optimizer)`` is called.
    """
    net = BDRAR().cuda().train()

    # Split parameters in a single pass, preserving registration order.
    bias_params, other_params = [], []
    for name, param in net.named_parameters():
        if name.endswith('bias'):
            bias_params.append(param)
        else:
            other_params.append(param)

    optimizer = optim.SGD([
        {'params': bias_params, 'lr': 2 * args['lr']},
        {'params': other_params, 'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        print('training resumes from \'%s\'' % args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['snapshot'] + '_optim.pth')))
        # Loading optimizer state also restores the saved learning rates;
        # override them with the ones derived from the current args.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Use a context manager so the log file is flushed and closed
    # deterministically before training starts appending to it.
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')
    train(net, optimizer)