Пример #1
0
    def boundary_detection(self, image):
        """Run the edge model on a single image and return its edge maps.

        The image is resized to 481x321, written to ``<script_dir>/test/tmp.jpg``
        and registered in ``<script_dir>/test.lst`` so that it can be read back
        through ``BSDS_RCFLoader`` / ``DataLoader``. The list file is emptied
        again before returning, also when inference raises.

        Args:
            image: array accepted by ``cv2.resize`` and ``cv2.imwrite``
                (presumably a BGR image as returned by ``cv2.imread`` --
                confirm against the callers).

        Returns:
            Tuple ``(res_fusion, graident, refine_img, lsd_grd)`` computed from
            the last batch yielded by the loader. ``res_fusion`` and
            ``graident`` are the last and first model outputs, squeezed,
            converted to NumPy and scaled by 255; ``refine_img`` and
            ``lsd_grd`` are whatever ``self.refinement_gradVConturMulti``
            returns for the model outputs.

        Raises:
            RuntimeError: if the loader yields no batch at all.
        """
        width_p = 481
        height_p = 321
        test_dir = os.path.join(script_dir, "test")
        list_path = os.path.join(script_dir, "test.lst")

        # build testloader
        os.makedirs(test_dir, exist_ok=True)
        image = cv2.resize(image, (width_p, height_p))
        cv2.imwrite(os.path.join(test_dir, "tmp.jpg"), image)

        # Close the list file right after writing: the entry has to be flushed
        # to disk before the dataset reads the list, and a handle that is still
        # open during the final truncation would re-append its buffered data
        # afterwards.
        with open(list_path, "a+") as fil:
            fil.write('test/tmp.jpg\n')

        res_fusion = graident = refine_img = lsd_grd = None
        try:
            test_dataset = BSDS_RCFLoader(root=os.path.join(script_dir), split="test")
            test_loader = DataLoader(
                test_dataset, batch_size=1,
                num_workers=8, drop_last=True, shuffle=False)

            for batch in test_loader:
                batch = batch.cuda()
                results = self.model(batch)
                refine_img, lsd_grd = self.refinement_gradVConturMulti(results)
                res_fusion = 255 * torch.squeeze(results[-1].detach()).cpu().numpy()
                graident = 255 * torch.squeeze(results[0].detach()).cpu().numpy()
                if SHOW_RESULTS:
                    plt.imshow(res_fusion)
                    plt.show()
        finally:
            # Always leave an empty list behind so that a failed run cannot
            # leak its entry into the next call.
            with open(list_path, 'w'):
                pass

        if res_fusion is None:
            raise RuntimeError(
                "boundary_detection: the test loader yielded no batch")

        return res_fusion, graident, refine_img, lsd_grd
Пример #2
0
def make_optim(model, lr):
    """Create the SGD optimizer used for training.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: learning rate.

    Returns:
        A ``torch.optim.SGD`` instance over all parameters of ``model`` with
        momentum 0.9 and weight decay 5e-4.
    """
    sgd_settings = {'lr': lr, 'momentum': 0.9, 'weight_decay': 5e-4}
    return torch.optim.SGD(model.parameters(), **sgd_settings)


def save_ckpt(model, name):
    """Write the model's state dict to ``ckpt/<name>.pth``.

    The ``ckpt`` directory is created in the current working directory when
    it does not exist yet. A progress line is printed (and flushed) first.

    Args:
        model: module providing ``state_dict()``.
        name: checkpoint name, used as the file stem.
    """
    ckpt_dir = 'ckpt'
    print(f'saving checkpoint ... {name}', flush=True)
    if not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    target = os.path.join(ckpt_dir, f'{name}.pth')
    torch.save(model.state_dict(), target)


# Training data: the "train" split, shuffled, in full batches only
# (a trailing incomplete batch is dropped), loaded by 8 worker processes.
train_dataset = BSDS_RCFLoader(split="train")
# test_dataset = BSDS_RCFLoader(split="test")
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    drop_last=True,
    num_workers=8,
    batch_size=batch_size,
)


def cross_entropy_loss_RCF(prediction, label):
    """Begin building the per-element weight mask for the RCF loss.

    NOTE(review): only the start of this function is present in this excerpt.
    ``prediction`` is not used yet and nothing is returned; the remaining
    weighting and the actual loss computation have to be restored from the
    full source.

    Args:
        prediction: model output (unused in the visible part).
        label: label tensor; the visible code treats the values 1 and 0
            specially.
    """
    # `mask` is a separate float tensor derived from the integer-cast label, so
    # the in-place assignment below does not modify `label`.
    label = label.long()
    mask = label.float()
    # Number of elements labelled exactly 1 and exactly 0.
    num_positive = torch.sum((mask == 1).float()).float()
    num_negative = torch.sum((mask == 0).float()).float()

    # Elements labelled 1 are replaced by the share of 0-labelled elements
    # among all 0/1-labelled elements.
    mask[mask == 1] = 1.0 * num_negative / (num_positive + num_negative)
Пример #3
0
def main():
    """Train RCF, evaluating and writing a checkpoint after every epoch.

    Steps, in order: build the train/test loaders, check the test list
    against the loader length, create the model (optionally resuming weights
    from ``args.resume``), split the parameters into ten groups with their own
    learning-rate multiplier and weight decay, redirect ``sys.stdout`` to a
    log file and run the epoch loop.
    """
    args.cuda = True

    # dataset
    train_dataset = BSDS_RCFLoader(root=args.dataset, split="train")
    test_dataset = BSDS_RCFLoader(root=args.dataset, split="test")
    shared_loader_kwargs = dict(batch_size=args.batch_size,
                                num_workers=8,
                                drop_last=True)
    train_loader = DataLoader(train_dataset, shuffle=True,
                              **shared_loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False,
                             **shared_loader_kwargs)
    with open('data/HED-BSDS_PASCAL/test.lst', 'r') as f:
        test_list = [split(line.rstrip())[1] for line in f]
    n_listed, n_batches = len(test_list), len(test_loader)
    assert n_listed == n_batches, "%d vs %d" % (n_listed, n_batches)

    # model
    model = RCF()
    model.cuda()
    model.apply(weights_init)
    load_vgg16pretrain(model)
    if args.resume:
        if not isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))

    # tune lr
    # One row per optimizer group:
    # (group key, lr multiplier, uses weight decay, console tag, members)
    group_specs = [
        ('conv1-4.weight', 1, True, 'lr:1 de:1', [
            'conv1_1.weight', 'conv1_2.weight', 'conv2_1.weight',
            'conv2_2.weight', 'conv3_1.weight', 'conv3_2.weight',
            'conv3_3.weight', 'conv4_1.weight', 'conv4_2.weight',
            'conv4_3.weight',
        ]),
        ('conv1-4.bias', 2, False, 'lr:2 de:0', [
            'conv1_1.bias', 'conv1_2.bias', 'conv2_1.bias', 'conv2_2.bias',
            'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias', 'conv4_1.bias',
            'conv4_2.bias', 'conv4_3.bias',
        ]),
        ('conv5.weight', 100, True, 'lr:100 de:1', [
            'conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight',
        ]),
        ('conv5.bias', 200, False, 'lr:200 de:0', [
            'conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias',
        ]),
        ('conv_down_1-5.weight', 0.1, True, 'lr:0.1 de:1', [
            'conv1_1_down.weight', 'conv1_2_down.weight',
            'conv2_1_down.weight', 'conv2_2_down.weight',
            'conv3_1_down.weight', 'conv3_2_down.weight',
            'conv3_3_down.weight', 'conv4_1_down.weight',
            'conv4_2_down.weight', 'conv4_3_down.weight',
            'conv5_1_down.weight', 'conv5_2_down.weight',
            'conv5_3_down.weight',
        ]),
        ('conv_down_1-5.bias', 0.2, False, 'lr:0.2 de:0', [
            'conv1_1_down.bias', 'conv1_2_down.bias', 'conv2_1_down.bias',
            'conv2_2_down.bias', 'conv3_1_down.bias', 'conv3_2_down.bias',
            'conv3_3_down.bias', 'conv4_1_down.bias', 'conv4_2_down.bias',
            'conv4_3_down.bias', 'conv5_1_down.bias', 'conv5_2_down.bias',
            'conv5_3_down.bias',
        ]),
        ('score_dsn_1-5.weight', 0.01, True, 'lr:0.01 de:1', [
            'score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
            'score_dsn4.weight', 'score_dsn5.weight',
        ]),
        ('score_dsn_1-5.bias', 0.02, False, 'lr:0.02 de:0', [
            'score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
            'score_dsn4.bias', 'score_dsn5.bias',
        ]),
        ('score_final.weight', 0.001, True, 'lr:0.001 de:1', [
            'score_final.weight',
        ]),
        ('score_final.bias', 0.002, False, 'lr:0.002 de:0', [
            'score_final.bias',
        ]),
    ]

    # Parameter name -> (group key, console tag); the member lists are
    # disjoint, so every name maps to exactly one group.
    group_of = {}
    for group_key, _, _, tag, members in group_specs:
        for member in members:
            group_of[member] = (group_key, tag)

    # Groups are created on first use; parameters that belong to no group are
    # neither reported nor optimized.
    net_parameters_id = {}
    for pname, p in model.named_parameters():
        hit = group_of.get(pname)
        if hit is None:
            continue
        group_key, tag = hit
        print(pname, tag)
        net_parameters_id.setdefault(group_key, []).append(p)

    # Indexing by key (not .get) keeps the KeyError for a group that received
    # no parameter.
    param_groups = [
        {
            'params': net_parameters_id[group_key],
            'lr': args.lr * lr_mult,
            'weight_decay': args.weight_decay if decayed else 0.
        }
        for group_key, lr_mult, decayed, _, _ in group_specs
    ]
    optimizer = torch.optim.SGD(param_groups,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.stepsize,
                                    gamma=args.gamma)

    # log
    # NOTE(review): '%d' drops the fractional part of args.lr, so a learning
    # rate below 1 shows up as 0 in the file name -- confirm this is intended.
    log = Logger(join(TMP_DIR, '%s-%d-log.txt' % ('sgd', args.lr)))
    sys.stdout = log

    train_loss = []
    train_loss_detail = []
    for epoch in range(args.start_epoch, args.maxepoch):
        eval_kwargs = dict(epoch=epoch, test_list=test_list)

        if epoch == 0:
            print("Performing initial testing...")
            multiscale_test(model, test_loader,
                            save_dir=join(TMP_DIR, 'initial-testing-record'),
                            **eval_kwargs)

        tr_avg_loss, tr_detail_loss = train(
            train_loader, model, optimizer, epoch,
            save_dir=join(TMP_DIR, 'epoch-%d-training-record' % epoch))
        test(model, test_loader,
             save_dir=join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch),
             **eval_kwargs)
        multiscale_test(model, test_loader,
                        save_dir=join(TMP_DIR,
                                      'epoch-%d-testing-record' % epoch),
                        **eval_kwargs)
        log.flush()  # write log

        # Save checkpoint
        snapshot = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_file = os.path.join(TMP_DIR,
                                 'checkpoint_epoch{}.pth'.format(epoch))
        save_checkpoint(snapshot, filename=save_file)
        scheduler.step()  # will adjust learning rate

        # save train/val loss/accuracy, save every epoch in case of early stop
        train_loss.append(tr_avg_loss)
        train_loss_detail.extend(tr_detail_loss)
Пример #4
0
    os.mkdir(png_folder)
    os.mkdir(mat_folder)
except Exception:
    print('dir already exist....')
    pass

# Build the network, wrap it for data-parallel execution on GPU 0, move it to
# `device` and switch it to evaluation mode.
model = models.resnet101(pretrained=False)
model = torch.nn.DataParallel(model, device_ids=(0, ))
model.to(device)
model.eval()

# resume..
# map_location places the restored tensors on `device` instead of on whatever
# device the checkpoint was saved from, which may not exist on this machine.
# The checkpoint is expected to be the state dict itself, with keys matching
# the DataParallel-wrapped model.
checkpoint = torch.load(resume, map_location=device)
model.load_state_dict(checkpoint)

# Test split, read in order by a single worker process.
test_dataset = BSDS_RCFLoader(split="test")
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         num_workers=1,
                         drop_last=True,
                         shuffle=False)

# Inference only: gradient tracking is disabled for the whole loop.
with torch.no_grad():
    # Every batch unpacks into the network input `image`, a second tensor
    # `ori` and `img_files`.
    for i, (image, ori, img_files) in enumerate(test_loader):
        # The trailing two dimensions of `ori` are handed to the model as a
        # size (presumably the original image height/width -- confirm).
        h, w = ori.size()[2:]
        image = image.to(device)

        outs = model(image, (h, w))

        # Last model output, squeezed and converted to a NumPy array.
        fuse = outs[-1].squeeze().detach().cpu().numpy()
        outs.append(ori)
        # NOTE(review): the loop body continues beyond this excerpt; `i`,
        # `img_files` and `fuse` are not used in the visible part.
Пример #5
0
def main():
    """Evaluate a stored RCF checkpoint on the frames listed for 'StudySpot'.

    The loaders, the per-group SGD optimizer, the scheduler and the log
    redirection of the training script are still set up, but the epoch loop
    is disabled in this variant: after the setup, the weights from
    ``RCFcheckpoint_epoch12.pth`` are loaded and ``test`` is run once on the
    images listed in ``data/StudySpot/test.lst``.
    """
    args.cuda = True

    # dataset
    train_dataset = BSDS_RCFLoader(root=args.dataset, split="train")
    test_dataset = BSDS_RCFLoader(root=args.dataset, split="test")
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=8,
                              drop_last=True,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             drop_last=True,
                             shuffle=False)

    with open('data/HED-BSDS_PASCAL/test.lst', 'r') as f:
        test_list = [split(line.rstrip())[1] for line in f]
    assert len(test_list) == len(test_loader), "%d vs %d" % (
        len(test_list), len(test_loader))

    # model
    model = RCF()
    model.cuda()
    model.apply(weights_init)
    load_vgg16pretrain(model, vgg_model_name)

    if args.resume:
        if not isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))

    # tune lr
    # One row per optimizer group:
    # (group key, lr multiplier, uses weight decay, console tag, members)
    group_specs = [
        ('conv1-4.weight', 1, True, 'lr:1 de:1', [
            'conv1_1.weight', 'conv1_2.weight',
            'conv2_1.weight', 'conv2_2.weight',
            'conv3_1.weight', 'conv3_2.weight', 'conv3_3.weight',
            'conv4_1.weight', 'conv4_2.weight', 'conv4_3.weight',
        ]),
        ('conv1-4.bias', 2, False, 'lr:2 de:0', [
            'conv1_1.bias', 'conv1_2.bias',
            'conv2_1.bias', 'conv2_2.bias',
            'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias',
            'conv4_1.bias', 'conv4_2.bias', 'conv4_3.bias',
        ]),
        ('conv5.weight', 100, True, 'lr:100 de:1', [
            'conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight',
        ]),
        ('conv5.bias', 200, False, 'lr:200 de:0', [
            'conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias',
        ]),
        ('conv_down_1-5.weight', 0.1, True, 'lr:0.1 de:1', [
            'conv1_1_down.weight', 'conv1_2_down.weight',
            'conv2_1_down.weight', 'conv2_2_down.weight',
            'conv3_1_down.weight', 'conv3_2_down.weight',
            'conv3_3_down.weight',
            'conv4_1_down.weight', 'conv4_2_down.weight',
            'conv4_3_down.weight',
            'conv5_1_down.weight', 'conv5_2_down.weight',
            'conv5_3_down.weight',
        ]),
        ('conv_down_1-5.bias', 0.2, False, 'lr:0.2 de:0', [
            'conv1_1_down.bias', 'conv1_2_down.bias',
            'conv2_1_down.bias', 'conv2_2_down.bias',
            'conv3_1_down.bias', 'conv3_2_down.bias', 'conv3_3_down.bias',
            'conv4_1_down.bias', 'conv4_2_down.bias', 'conv4_3_down.bias',
            'conv5_1_down.bias', 'conv5_2_down.bias', 'conv5_3_down.bias',
        ]),
        ('score_dsn_1-5.weight', 0.01, True, 'lr:0.01 de:1', [
            'score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
            'score_dsn4.weight', 'score_dsn5.weight',
        ]),
        ('score_dsn_1-5.bias', 0.02, False, 'lr:0.02 de:0', [
            'score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
            'score_dsn4.bias', 'score_dsn5.bias',
        ]),
        ('score_final.weight', 0.001, True, 'lr:0.001 de:1', [
            'score_final.weight',
        ]),
        ('score_final.bias', 0.002, False, 'lr:0.002 de:0', [
            'score_final.bias',
        ]),
    ]

    # Assign every named parameter to the first group that lists it; groups
    # are created on first use and unlisted parameters are skipped silently.
    net_parameters_id = {}
    for pname, p in model.named_parameters():
        for group_key, _, _, tag, members in group_specs:
            if pname in members:
                print(pname, tag)
                net_parameters_id.setdefault(group_key, []).append(p)
                break

    # Indexing by key (not .get) keeps the KeyError for a group that received
    # no parameter.
    param_groups = []
    for group_key, lr_mult, decayed, _, _ in group_specs:
        param_groups.append({
            'params': net_parameters_id[group_key],
            'lr': args.lr * lr_mult,
            'weight_decay': args.weight_decay if decayed else 0.,
        })
    optimizer = torch.optim.SGD(param_groups,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.stepsize,
                                    gamma=args.gamma)

    # log
    # NOTE(review): '%d' drops the fractional part of args.lr, so a learning
    # rate below 1 shows up as 0 in the file name -- confirm this is intended.
    log = Logger(join(TMP_DIR, '%s-%d-log.txt' % ('sgd', args.lr)))
    sys.stdout = log

    # Only filled by the epoch loop, which is disabled in this variant.
    train_loss = []
    train_loss_detail = []

    print("The start is epoch", args.start_epoch)
    print("The max is epoch", args.maxepoch)

    # Testing the pretrained model over the test images
    checkpoint = torch.load("RCFcheckpoint_epoch12.pth")
    model.load_state_dict(checkpoint['state_dict'])

    #########################
    # Test for our dataset!
    ########################
    # The frames are expected to exist already under data/<vid_name>/test and
    # to be listed in data/<vid_name>/test.lst; extracting them from the video
    # is not done here.
    vid_name = 'StudySpot'
    test_dataset = BSDS_RCFLoader(root='data/'+vid_name, split="test")
    print(test_dataset.filelist)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             drop_last=True,
                             shuffle=False)

    with open('data/'+vid_name+'/test.lst', 'r') as f:
        test_list = [split(line.rstrip())[1] for line in f]
    assert len(test_list) == len(test_loader), "%d vs %d" % (
        len(test_list), len(test_loader))

    epoch = 0
    print("Performing testing...")
    view_dir = join(TMP_DIR, vid_name,
                    'epoch-%d-testing-record-view' % epoch)
    test(model, test_loader, epoch=epoch, test_list=test_list,
         save_dir=view_dir)