コード例 #1
0
def main():
    """Predict a map for every file of every dataset listed in ``to_test``.

    Loads ``<ckpt_path>/<exp_name>/<snapshot>.pth`` into an ``R3Net`` on the
    GPU, runs it on each file found in a dataset root, optionally passes the
    result through ``crf_refine`` and, when ``args['save_results']`` is set,
    writes the prediction under the input file's own name.

    This variant only predicts; nothing is scored against ground truth.
    """
    net = R3Net().cuda()

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth')))
    net.eval()

    with torch.no_grad():

        for name, root in to_test.items():
            # Build the output directory once, so that the directory created
            # here is the very one the predictions are written into below.
            save_dir = os.path.join(
                ckpt_path, exp_name, '%s_%s_%s' %
                (exp_name, name, args['snapshot'] + "kaist_test"))
            if args['save_results']:
                check_mkdir(save_dir)

            img_list = os.listdir(root)
            print(img_list)
            total = len(img_list)
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, total))
                img = Image.open(os.path.join(root, img_name)).convert('RGB')
                # torch.no_grad() already disables autograd, so no
                # Variable(volatile=True) wrapper is needed here.
                img_var = img_transform(img).unsqueeze(0).cuda()
                prediction = net(img_var)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(save_dir, img_name))
コード例 #2
0
def main():
    """Evaluate the snapshot named in ``args`` on every dataset in ``to_test``.

    For each dataset, every ``.jpg`` in its root is fed through the network;
    the prediction is optionally refined with ``crf_refine`` and compared with
    the same-named ``.png`` in that root via ``cal_precision_recall_mae``.
    The per-image values are accumulated in 256 precision meters, 256 recall
    meters and one MAE meter.  The resulting F-measure and mean MAE of each
    dataset are collected in a dict and printed at the end.
    """
    net = R3Net().cuda()

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
    net.eval()

    results = {}

    for name, root in to_test.items():

        precision_record = [AvgMeter() for _ in range(256)]
        recall_record = [AvgMeter() for _ in range(256)]
        mae_record = AvgMeter()

        # The output directory does not depend on the image, so create it once.
        save_dir = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))
        check_mkdir(save_dir)

        img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
        total = len(img_list)
        for idx, img_name in enumerate(img_list):
            print('predicting for %s: %d / %d' % (name, idx + 1, total))

            img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
            # NOTE(review): `volatile=True` only disables autograd on
            # PyTorch < 0.4; on newer releases wrap the loop in
            # torch.no_grad() instead -- confirm the targeted version.
            img_var = Variable(img_transform(img).unsqueeze(0), volatile=True).cuda()
            prediction = net(img_var)
            prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

            if args['crf_refine']:
                prediction = crf_refine(np.array(img), prediction)

            gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L'))
            precision, recall, mae = cal_precision_recall_mae(prediction, gt)
            for pidx, (p, r) in enumerate(zip(precision, recall)):
                precision_record[pidx].update(p)
                recall_record[pidx].update(r)
            mae_record.update(mae)

            if args['save_results']:
                Image.fromarray(prediction).save(os.path.join(save_dir, img_name + '.png'))

        fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                [rrecord.avg for rrecord in recall_record])

        results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
コード例 #3
0
ファイル: train_sasc2.py プロジェクト: piggy2008/STDistill
def main():
    """Build the network and its SGD optimizer, then hand both to ``train``.

    Bias parameters get twice the base learning rate and no weight decay;
    all other parameters use ``args['lr']`` with ``args['weight_decay']``.
    When ``args['snapshot']`` is non-empty, model and optimizer state are
    restored from ``<ckpt_path>/<exp_name>/`` and the learning rates are
    reset to the configured values.  When ``args['pretrain']`` is non-empty,
    weights from that file are loaded through ``load_part_of_model``.
    The run configuration is written to ``log_path`` before training starts.
    """
    net = R3Net(motion=args['motion'],
                se_layer=args['se_layer'],
                dilation=args['dilation'],
                basic_model=args['basic_model']).cuda().train()

    # fix_parameters(net.named_parameters())
    bias_params = [
        param for name, param in net.named_parameters()
        if name.endswith('bias')
    ]
    other_params = [
        param for name, param in net.named_parameters()
        if not name.endswith('bias')
    ]
    optimizer = optim.SGD(
        [
            {'params': bias_params, 'lr': 2 * args['lr']},
            {
                'params': other_params,
                'lr': args['lr'],
                'weight_decay': args['weight_decay'],
            },
        ],
        momentum=args['momentum'])

    if args['snapshot']:
        print('training resumes from ' + args['snapshot'])
        snapshot_base = os.path.join(ckpt_path, exp_name, args['snapshot'])
        net.load_state_dict(torch.load(snapshot_base + '.pth'))
        optimizer.load_state_dict(torch.load(snapshot_base + '_optim.pth'))
        # The restored optimizer state carries the learning rates saved with
        # it; overwrite them with the currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    if args['pretrain']:
        print('pretrain model from ' + args['pretrain'])
        net = load_part_of_model(net, args['pretrain'], device_id=device_id)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Use a context manager so the log file is flushed and closed before the
    # training loop starts.
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')
    train(net, optimizer)
コード例 #4
0
ファイル: train.py プロジェクト: xw-hu/R3Net
def main():
    """Build an ``R3Net`` and its SGD optimizer, then hand both to ``train``.

    Bias parameters get twice the base learning rate and no weight decay;
    all other parameters use ``args['lr']`` with ``args['weight_decay']``.
    When ``args['snapshot']`` is non-empty, model and optimizer state are
    restored from ``<ckpt_path>/<exp_name>/`` and the learning rates are
    reset to the configured values.  The run configuration is written to
    ``log_path`` before training starts.
    """
    net = R3Net().cuda().train()

    bias_params = [param for name, param in net.named_parameters() if name.endswith('bias')]
    other_params = [param for name, param in net.named_parameters() if not name.endswith('bias')]
    optimizer = optim.SGD([
        {'params': bias_params, 'lr': 2 * args['lr']},
        {'params': other_params, 'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])

    if args['snapshot']:
        print('training resumes from ' + args['snapshot'])
        snapshot_base = os.path.join(ckpt_path, exp_name, args['snapshot'])
        net.load_state_dict(torch.load(snapshot_base + '.pth'))
        optimizer.load_state_dict(torch.load(snapshot_base + '_optim.pth'))
        # The restored optimizer state carries the learning rates saved with
        # it; overwrite them with the currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Use a context manager so the log file is flushed and closed before the
    # training loop starts.
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')
    train(net, optimizer)
コード例 #5
0
ファイル: infer.py プロジェクト: piggy2008/R3Net
def main():
    """Evaluate the snapshot named in ``args`` on every dataset in ``to_test``.

    Each line of the file at ``imgs_path`` is a comma-separated list of frame
    names.  The frames are loaded from the dataset root, resized to
    ``args['input_size']`` and concatenated along dimension 0 into one network
    input.  The prediction is resized back to the size of the last frame,
    rescaled with ``MaxMinNormalization`` and multiplied by 255, optionally
    refined with ``crf_refine``, and compared with
    ``<gt_root>/<last frame>.png``.  The F-measure and mean MAE of each
    dataset are collected in a dict and printed at the end.
    """
    net = R3Net(motion='', se_layer=False, attention=False, basic_model='resnext101')

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'), map_location='cuda:0'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            result_dir = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))
            if args['save_results']:
                check_mkdir(result_dir)

            # Close the list file as soon as it has been read.
            with open(imgs_path) as list_file:
                img_list = [line.strip() for line in list_file]

            # Frames of these two datasets are read as .png, all others as .jpg.
            img_ext = '.png' if name in ('VOS', 'DAVSOD') else '.jpg'

            for idx, img_names in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                frames = []
                for img_name in img_names.split(','):
                    raw_img = Image.open(os.path.join(root, img_name + img_ext)).convert('RGB')
                    resized = raw_img.resize(args['input_size'])
                    # torch.no_grad() already disables autograd, so no
                    # Variable(volatile=True) wrapper is needed here.
                    frames.append(img_transform(resized).unsqueeze(0).cuda())
                # After the loop `img_name` / `raw_img` refer to the last frame
                # of the sequence, which is the one evaluated and saved.
                shape = raw_img.size
                img_var = torch.cat(frames, dim=0)

                # CUDA kernels run asynchronously; synchronise so the interval
                # measured covers the forward pass itself.
                torch.cuda.synchronize()
                start = time.time()
                prediction = net(img_var)
                torch.cuda.synchronize()
                end = time.time()
                print('running time:', (end - start))

                pred_img = to_pil(prediction.data.squeeze(0).cpu()).resize(shape)
                prediction = np.array(pred_img).astype('float')
                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                if args['crf_refine']:
                    # The prediction is back at the original resolution, so
                    # pass the un-resized frame to keep both the same size.
                    prediction = crf_refine(np.array(raw_img), prediction)

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, (p, r) in enumerate(zip(precision, recall)):
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(result_dir, folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
コード例 #6
0
    x = (x - Min) / (Max - Min)
    return x


def load_part_of_model(new_model, src_model_path, device_id=0):
    """Copy the weights stored at ``src_model_path`` into ``new_model``.

    Only entries whose key also exists in ``new_model.state_dict()`` are
    transferred.  Entries of ``new_model`` that the checkpoint does not
    mention keep their current values; checkpoint entries unknown to
    ``new_model`` are reported and skipped instead of raising ``KeyError``.

    Args:
        new_model: module that receives the weights (modified in place).
        src_model_path: path of a ``torch.save`` file holding a mapping of
            state-dict key to tensor.
        device_id: index of the CUDA device the checkpoint is mapped to.
            Ignored when CUDA is unavailable; tensors are then mapped to CPU.

    Returns:
        ``new_model`` itself, after ``load_state_dict``.
    """
    if torch.cuda.is_available():
        map_location = 'cuda:' + str(device_id)
    else:
        # Keeps the loader usable on hosts without a GPU.
        map_location = 'cpu'
    src_model = torch.load(src_model_path, map_location=map_location)

    m_dict = new_model.state_dict()
    for k, param in src_model.items():
        if k not in m_dict:
            print('skip %s (not present in the target model)' % k)
            continue
        print(k)
        m_dict[k] = param

    new_model.load_state_dict(m_dict)
    return new_model


if __name__ == '__main__':
    # Script entry point: locate the snapshot of the chosen experiment and
    # load its weights into an R3Net built with motion='GRU'.
    ckpt_path = './ckpt'
    exp_name = 'VideoSaliency_2019-05-14 17:13:16'

    args = dict(
        snapshot='30000',  # your snapshot filename (exclude extension name)
        crf_refine=False,  # whether to use crf to refine results
        save_results=True,  # whether to save the resulting masks
        input_size=(473, 473),
    )
    src_model_path = os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')
    net = load_part_of_model(R3Net(motion='GRU'), src_model_path)