Example No. 1
def __init__(self):
    # model
    self.model = RCF()
    self.model.cuda()
    self.vgg_model_name = os.path.join(script_dir, "fast-rcnn-vgg16-pascal07-dagnn.mat")
    self.weights_pretraind = os.path.join(script_dir, "RCFcheckpoint_epoch12.pth")
    self.model.apply(weights_init)
    load_vgg16pretrain(self.model, self.vgg_model_name)
    self.checkpoint = torch.load(self.weights_pretraind)
    self.model.load_state_dict(self.checkpoint['state_dict'])
    self.model.eval()
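
A minimal usage sketch, not part of the original example: it assumes this __init__ belongs to a detector class (see Example No. 4 below) and that RCF.forward returns a list of side-output maps.

import torch

# hypothetical driver; detector stands for an instance of the surrounding class
with torch.no_grad():
    dummy = torch.rand(1, 3, 321, 481).cuda()  # NCHW image tensor on the GPU
    outputs = detector.model(dummy)            # list of side-output maps
    fused = outputs[-1]                        # the last map is the fused result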
Example No. 2
def main():
    args.cuda = True
    # dataset
    train_dataset = BSDS_RCFLoader(root=args.dataset, split="train")
    test_dataset = BSDS_RCFLoader(root=args.dataset, split="test")
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=8,
                              drop_last=True,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             drop_last=True,
                             shuffle=False)
    with open('data/HED-BSDS_PASCAL/test.lst', 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list),
                                                             len(test_loader))
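    # note: len(test_loader) counts batches, so this assertion assumes
    # batch_size == 1; with a larger batch size it would fail even for a
    # correct test list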

    # model
    model = RCF()
    model.cuda()
    model.apply(weights_init)
    load_vgg16pretrain(model)
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    #tune lr
    net_parameters_id = {}
    net = model
    for pname, p in net.named_parameters():
        if pname in [
                'conv1_1.weight', 'conv1_2.weight', 'conv2_1.weight',
                'conv2_2.weight', 'conv3_1.weight', 'conv3_2.weight',
                'conv3_3.weight', 'conv4_1.weight', 'conv4_2.weight',
                'conv4_3.weight'
        ]:
            print(pname, 'lr:1 de:1')
            if 'conv1-4.weight' not in net_parameters_id:
                net_parameters_id['conv1-4.weight'] = []
            net_parameters_id['conv1-4.weight'].append(p)
        elif pname in [
                'conv1_1.bias', 'conv1_2.bias', 'conv2_1.bias', 'conv2_2.bias',
                'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias', 'conv4_1.bias',
                'conv4_2.bias', 'conv4_3.bias'
        ]:
            print(pname, 'lr:2 de:0')
            if 'conv1-4.bias' not in net_parameters_id:
                net_parameters_id['conv1-4.bias'] = []
            net_parameters_id['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight']:
            print(pname, 'lr:100 de:1')
            if 'conv5.weight' not in net_parameters_id:
                net_parameters_id['conv5.weight'] = []
            net_parameters_id['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias']:
            print(pname, 'lr:200 de:0')
            if 'conv5.bias' not in net_parameters_id:
                net_parameters_id['conv5.bias'] = []
            net_parameters_id['conv5.bias'].append(p)
        elif pname in [
                'conv1_1_down.weight', 'conv1_2_down.weight',
                'conv2_1_down.weight', 'conv2_2_down.weight',
                'conv3_1_down.weight', 'conv3_2_down.weight',
                'conv3_3_down.weight', 'conv4_1_down.weight',
                'conv4_2_down.weight', 'conv4_3_down.weight',
                'conv5_1_down.weight', 'conv5_2_down.weight',
                'conv5_3_down.weight'
        ]:
            print(pname, 'lr:0.1 de:1')
            if 'conv_down_1-5.weight' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.weight'] = []
            net_parameters_id['conv_down_1-5.weight'].append(p)
        elif pname in [
                'conv1_1_down.bias', 'conv1_2_down.bias', 'conv2_1_down.bias',
                'conv2_2_down.bias', 'conv3_1_down.bias', 'conv3_2_down.bias',
                'conv3_3_down.bias', 'conv4_1_down.bias', 'conv4_2_down.bias',
                'conv4_3_down.bias', 'conv5_1_down.bias', 'conv5_2_down.bias',
                'conv5_3_down.bias'
        ]:
            print(pname, 'lr:0.2 de:0')
            if 'conv_down_1-5.bias' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.bias'] = []
            net_parameters_id['conv_down_1-5.bias'].append(p)
        elif pname in [
                'score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
                'score_dsn4.weight', 'score_dsn5.weight'
        ]:
            print(pname, 'lr:0.01 de:1')
            if 'score_dsn_1-5.weight' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.weight'] = []
            net_parameters_id['score_dsn_1-5.weight'].append(p)
        elif pname in [
                'score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
                'score_dsn4.bias', 'score_dsn5.bias'
        ]:
            print(pname, 'lr:0.02 de:0')
            if 'score_dsn_1-5.bias' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.bias'] = []
            net_parameters_id['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_final.weight']:
            print(pname, 'lr:0.001 de:1')
            if 'score_final.weight' not in net_parameters_id:
                net_parameters_id['score_final.weight'] = []
            net_parameters_id['score_final.weight'].append(p)
        elif pname in ['score_final.bias']:
            print(pname, 'lr:0.002 de:0')
            if 'score_final.bias' not in net_parameters_id:
                net_parameters_id['score_final.bias'] = []
            net_parameters_id['score_final.bias'].append(p)
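    # every backbone and head parameter is now bucketed by layer type; the
    # optimizer below gives each bucket its own lr multiplier and weight decay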

    optimizer = torch.optim.SGD([
        {
            'params': net_parameters_id['conv1-4.weight'],
            'lr': args.lr * 1,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv1-4.bias'],
            'lr': args.lr * 2,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['conv5.weight'],
            'lr': args.lr * 100,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv5.bias'],
            'lr': args.lr * 200,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['conv_down_1-5.weight'],
            'lr': args.lr * 0.1,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv_down_1-5.bias'],
            'lr': args.lr * 0.2,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['score_dsn_1-5.weight'],
            'lr': args.lr * 0.01,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['score_dsn_1-5.bias'],
            'lr': args.lr * 0.02,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['score_final.weight'],
            'lr': args.lr * 0.001,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['score_final.bias'],
            'lr': args.lr * 0.002,
            'weight_decay': 0.
        },
    ],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.stepsize,
                                    gamma=args.gamma)
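    # the base lr passed to SGD is overridden by each group's own 'lr', and
    # StepLR scales every group's lr by gamma once per args.stepsize epochs
    # (scheduler.step() runs once per epoch below)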

    # optimizer = torch.optim.Adam([
    #         {'params': net_parameters_id['conv1-4.weight']      , 'lr': args.lr*1    , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv1-4.bias']        , 'lr': args.lr*2    , 'weight_decay': 0.},
    #         {'params': net_parameters_id['conv5.weight']        , 'lr': args.lr*100  , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv5.bias']          , 'lr': args.lr*200  , 'weight_decay': 0.},
    #         {'params': net_parameters_id['conv_down_1-5.weight'], 'lr': args.lr*0.1  , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv_down_1-5.bias']  , 'lr': args.lr*0.2  , 'weight_decay': 0.},
    #         {'params': net_parameters_id['score_dsn_1-5.weight'], 'lr': args.lr*0.01 , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['score_dsn_1-5.bias']  , 'lr': args.lr*0.02 , 'weight_decay': 0.},
    #         {'params': net_parameters_id['score_final.weight']  , 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['score_final.bias']    , 'lr': args.lr*0.002, 'weight_decay': 0.},
    #     ], lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    # log
    log = Logger(join(TMP_DIR, '%s-%g-log.txt' % ('sgd', args.lr)))
    sys.stdout = log

    train_loss = []
    train_loss_detail = []
    for epoch in range(args.start_epoch, args.maxepoch):
        if epoch == 0:
            print("Performing initial testing...")
            multiscale_test(model,
                            test_loader,
                            epoch=epoch,
                            test_list=test_list,
                            save_dir=join(TMP_DIR, 'initial-testing-record'))

        tr_avg_loss, tr_detail_loss = train(
            train_loader,
            model,
            optimizer,
            epoch,
            save_dir=join(TMP_DIR, 'epoch-%d-training-record' % epoch))
        test(model,
             test_loader,
             epoch=epoch,
             test_list=test_list,
             save_dir=join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch))
        multiscale_test(model,
                        test_loader,
                        epoch=epoch,
                        test_list=test_list,
                        save_dir=join(TMP_DIR,
                                      'epoch-%d-testing-record' % epoch))
        log.flush()  # write log
        # Save checkpoint
        save_file = os.path.join(TMP_DIR,
                                 'checkpoint_epoch{}.pth'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            filename=save_file)
        scheduler.step()  # will adjust learning rate
        # save train/val loss/accuracy, save every epoch in case of early stop
        train_loss.append(tr_avg_loss)
        train_loss_detail += tr_detail_loss
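
A more compact equivalent of the grouping loop above, as a sketch (this is not the original code): each lr multiplier is derived from a substring rule, relying on the RCF parameter names shown in the example. Matching order matters ('down' before 'conv5' before 'conv').

def build_param_groups(model, base_lr, weight_decay):
    # (substring, weight-lr multiplier, bias-lr multiplier); first match wins
    rules = [
        ('score_final', 0.001, 0.002),
        ('score_dsn',   0.01,  0.02),
        ('down',        0.1,   0.2),
        ('conv5',       100.,  200.),
        ('conv',        1.,    2.),
    ]
    groups = []
    for name, p in model.named_parameters():
        for key, w_mult, b_mult in rules:
            if key in name:
                is_bias = name.endswith('bias')
                groups.append({
                    'params': [p],
                    'lr': base_lr * (b_mult if is_bias else w_mult),
                    'weight_decay': 0. if is_bias else weight_decay,
                })
                break
    return groups

Each parameter lands in its own group here, which is equivalent for SGD since the hyper-parameters are identical within a bucket.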
Example No. 3
def main():
    args.cuda = True
    # dataset
    train_dataset = BSDS_RCFLoader(root=args.dataset, split="train")
    test_dataset = BSDS_RCFLoader(root=args.dataset, split="test")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size,
        num_workers=8, drop_last=True, shuffle=True)
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=8, drop_last=True, shuffle=False)


    with open('data/HED-BSDS_PASCAL/test.lst', 'r') as f:
    # with open('data/HED-BSDS/test.lst', 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list), len(test_loader))

    # model
    model = RCF()
    model.cuda()
    model.apply(weights_init)
    # vgg_model_name is expected to be defined at module scope: the path to the
    # converted VGG16 weights (cf. the .mat path in Examples No. 1 and 4)
    load_vgg16pretrain(model, vgg_model_name)

    if args.resume:
        if isfile(args.resume): 
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'"
                  .format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    
    #tune lr
    net_parameters_id = {}
    net = model
    for pname, p in net.named_parameters():
        if pname in ['conv1_1.weight','conv1_2.weight',
                     'conv2_1.weight','conv2_2.weight',
                     'conv3_1.weight','conv3_2.weight','conv3_3.weight',
                     'conv4_1.weight','conv4_2.weight','conv4_3.weight']:
            print(pname, 'lr:1 de:1')
            if 'conv1-4.weight' not in net_parameters_id:
                net_parameters_id['conv1-4.weight'] = []
            net_parameters_id['conv1-4.weight'].append(p)
        elif pname in ['conv1_1.bias','conv1_2.bias',
                       'conv2_1.bias','conv2_2.bias',
                       'conv3_1.bias','conv3_2.bias','conv3_3.bias',
                       'conv4_1.bias','conv4_2.bias','conv4_3.bias']:
            print(pname, 'lr:2 de:0')
            if 'conv1-4.bias' not in net_parameters_id:
                net_parameters_id['conv1-4.bias'] = []
            net_parameters_id['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight','conv5_2.weight','conv5_3.weight']:
            print(pname, 'lr:100 de:1')
            if 'conv5.weight' not in net_parameters_id:
                net_parameters_id['conv5.weight'] = []
            net_parameters_id['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias','conv5_2.bias','conv5_3.bias'] :
            print(pname, 'lr:200 de:0')
            if 'conv5.bias' not in net_parameters_id:
                net_parameters_id['conv5.bias'] = []
            net_parameters_id['conv5.bias'].append(p)
        elif pname in ['conv1_1_down.weight','conv1_2_down.weight',
                       'conv2_1_down.weight','conv2_2_down.weight',
                       'conv3_1_down.weight','conv3_2_down.weight','conv3_3_down.weight',
                       'conv4_1_down.weight','conv4_2_down.weight','conv4_3_down.weight',
                       'conv5_1_down.weight','conv5_2_down.weight','conv5_3_down.weight']:
            print(pname, 'lr:0.1 de:1')
            if 'conv_down_1-5.weight' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.weight'] = []
            net_parameters_id['conv_down_1-5.weight'].append(p)
        elif pname in ['conv1_1_down.bias','conv1_2_down.bias',
                       'conv2_1_down.bias','conv2_2_down.bias',
                       'conv3_1_down.bias','conv3_2_down.bias','conv3_3_down.bias',
                       'conv4_1_down.bias','conv4_2_down.bias','conv4_3_down.bias',
                       'conv5_1_down.bias','conv5_2_down.bias','conv5_3_down.bias']:
            print(pname, 'lr:0.2 de:0')
            if 'conv_down_1-5.bias' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.bias'] = []
            net_parameters_id['conv_down_1-5.bias'].append(p)
        elif pname in ['score_dsn1.weight','score_dsn2.weight','score_dsn3.weight',
                       'score_dsn4.weight','score_dsn5.weight']:
            print(pname, 'lr:0.01 de:1')
            if 'score_dsn_1-5.weight' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.weight'] = []
            net_parameters_id['score_dsn_1-5.weight'].append(p)
        elif pname in ['score_dsn1.bias','score_dsn2.bias','score_dsn3.bias',
                       'score_dsn4.bias','score_dsn5.bias']:
            print(pname, 'lr:0.02 de:0')
            if 'score_dsn_1-5.bias' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.bias'] = []
            net_parameters_id['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_final.weight']:
            print(pname, 'lr:0.001 de:1')
            if 'score_final.weight' not in net_parameters_id:
                net_parameters_id['score_final.weight'] = []
            net_parameters_id['score_final.weight'].append(p)
        elif pname in ['score_final.bias']:
            print(pname, 'lr:0.002 de:0')
            if 'score_final.bias' not in net_parameters_id:
                net_parameters_id['score_final.bias'] = []
            net_parameters_id['score_final.bias'].append(p)

    optimizer = torch.optim.SGD([
            {'params': net_parameters_id['conv1-4.weight']      , 'lr': args.lr*1    , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['conv1-4.bias']        , 'lr': args.lr*2    , 'weight_decay': 0.},
            {'params': net_parameters_id['conv5.weight']        , 'lr': args.lr*100  , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['conv5.bias']          , 'lr': args.lr*200  , 'weight_decay': 0.},
            {'params': net_parameters_id['conv_down_1-5.weight'], 'lr': args.lr*0.1  , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['conv_down_1-5.bias']  , 'lr': args.lr*0.2  , 'weight_decay': 0.},
            {'params': net_parameters_id['score_dsn_1-5.weight'], 'lr': args.lr*0.01 , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['score_dsn_1-5.bias']  , 'lr': args.lr*0.02 , 'weight_decay': 0.},
            {'params': net_parameters_id['score_final.weight']  , 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['score_final.bias']    , 'lr': args.lr*0.002, 'weight_decay': 0.},
        ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)



    # optimizer = torch.optim.Adam([
    #         {'params': net_parameters_id['conv1-4.weight']      , 'lr': args.lr*1    , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv1-4.bias']        , 'lr': args.lr*2    , 'weight_decay': 0.},
    #         {'params': net_parameters_id['conv5.weight']        , 'lr': args.lr*100  , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv5.bias']          , 'lr': args.lr*200  , 'weight_decay': 0.},
    #         {'params': net_parameters_id['conv_down_1-5.weight'], 'lr': args.lr*0.1  , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv_down_1-5.bias']  , 'lr': args.lr*0.2  , 'weight_decay': 0.},
    #         {'params': net_parameters_id['score_dsn_1-5.weight'], 'lr': args.lr*0.01 , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['score_dsn_1-5.bias']  , 'lr': args.lr*0.02 , 'weight_decay': 0.},
    #         {'params': net_parameters_id['score_final.weight']  , 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['score_final.bias']    , 'lr': args.lr*0.002, 'weight_decay': 0.},
    #     ], lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    # log
    log = Logger(join(TMP_DIR, '%s-%g-log.txt' % ('sgd', args.lr)))
    sys.stdout = log

    train_loss = []
    train_loss_detail = []

    print("The start is epoch", args.start_epoch)
    print("The max is epoch",args.maxepoch)
    # for epoch in range(args.start_epoch, args.maxepoch):
    #     if epoch == 0:
    #         print("Performing initial testing...")
    #         multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #              save_dir = join(TMP_DIR, 'initial-testing-record'))
    #
    #     tr_avg_loss, tr_detail_loss = train(
    #         train_loader, model, optimizer, epoch,
    #         save_dir = join(TMP_DIR, 'epoch-%d-training-record' % epoch))
    #
    #     test(model, test_loader, epoch=epoch, test_list=test_list,
    #         save_dir = join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch))
    #
    #     multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #         save_dir = join(TMP_DIR, 'epoch-%d-testing-record' % epoch))
    #
    #     log.flush() # write log
    #     # Save checkpoint
    #     save_file = os.path.join(TMP_DIR, 'checkpoint_epoch{}.pth'.format(epoch))
    #     save_checkpoint({
    #         'epoch': epoch,
    #         'state_dict': model.state_dict(),
    #         'optimizer': optimizer.state_dict()
    #                      }, filename=save_file)
    #     scheduler.step() # will adjust learning rate
        ## save train/val loss/accuracy, save every epoch in case of early stop
        # train_loss.append(tr_avg_loss)
        # train_loss_detail += tr_detail_loss


    # Testing the pretrained model on the test images
    checkpoint = torch.load("RCFcheckpoint_epoch12.pth")
    model.load_state_dict(checkpoint['state_dict'])
    epoch = 0
    # print("Performing initial testing...")
    # multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #              save_dir = join(TMP_DIR, 'initial-testing-record'))
    # test(model, test_loader, epoch=epoch, test_list=test_list,
    #         save_dir = join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch))
    #

    #########################
    # Test on our own dataset
    #########################
    vid_name = 'StudySpot'
    # # Read the video from specified path
    # cam = cv2.VideoCapture('data/Book_Occ3.MOV')
    # try:
    #     # creating a folder named data
    #     if not os.path.exists('data/'+vid_name+'/test'):
    #         os.makedirs('data/'+vid_name+'/test')
    #         # if not created then raise error
    # except OSError:
    #     print('Error: Creating directory of data')
    #     # frame
    # currentframe = 0
    # fil = open('data/'+vid_name+'/test.lst', "a+")
    # width_p = 481
    # height_p = 321
    #
    # while (True):
    #     # reading from frame
    #     ret, frame = cam.read()
    #     if ret:
    #         # if video is still left continue creating images
    #         name = 'data/' +vid_name+'/test/'+ str(currentframe) + '.jpg'
    #         print('Creating...' + name)
    #         # writing the extracted images
    #         frame = cv2.resize(frame, (width_p, height_p))
    #         cv2.imwrite(name, frame)
    #         fil.write('test/'+ str(currentframe) + '.jpg\n')
    #         # increasing counter so that it will
    #         # show how many frames are created
    #         currentframe += 1
    #
    #     else:
    #         break
    #
    # # Release all space and windows once done
    # cam.release()
    # fil.close()
    # cv2.destroyAllWindows()

    test_dataset = BSDS_RCFLoader(root='data/'+vid_name, split="test")
    print(test_dataset.filelist)

    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=8, drop_last=True, shuffle=False)

    with open('data/'+vid_name+'/test.lst', 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list), len(test_loader))


    epoch = 0
    print("Performing testing...")
    # multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #              save_dir = join(TMP_DIR,vid_name,'initial-testing-record'))
    test(model, test_loader, epoch=epoch, test_list=test_list,
            save_dir=join(TMP_DIR, vid_name, 'epoch-%d-testing-record-view' % epoch))
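
The commented-out frame-extraction step above can be turned into a small helper. A sketch under the same assumptions as the comments (OpenCV available, frames resized to 481x321, list file at data/<vid_name>/test.lst):

import os
import cv2

def video_to_test_frames(video_path, vid_name, width_p=481, height_p=321):
    # dump every frame of the video into data/<vid_name>/test/ and record
    # each relative path in data/<vid_name>/test.lst for BSDS_RCFLoader
    out_dir = os.path.join('data', vid_name, 'test')
    os.makedirs(out_dir, exist_ok=True)
    cam = cv2.VideoCapture(video_path)
    idx = 0
    with open(os.path.join('data', vid_name, 'test.lst'), 'w') as lst:
        while True:
            ret, frame = cam.read()
            if not ret:
                break
            frame = cv2.resize(frame, (width_p, height_p))
            cv2.imwrite(os.path.join(out_dir, '%d.jpg' % idx), frame)
            lst.write('test/%d.jpg\n' % idx)
            idx += 1
    cam.release()

Unlike the commented original, this opens test.lst in 'w' mode so stale entries from a previous run are not kept.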
Example No. 4
class RCF_BoundaryOcclusionBoundaryDetector:

    def __init__(self):
        # model
        self.model = RCF()
        self.model.cuda()
        self.vgg_model_name = os.path.join(script_dir, "fast-rcnn-vgg16-pascal07-dagnn.mat")
        self.weights_pretraind = os.path.join(script_dir, "RCFcheckpoint_epoch12.pth")
        self.model.apply(weights_init)
        load_vgg16pretrain(self.model, self.vgg_model_name)
        self.checkpoint = torch.load(self.weights_pretraind)
        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model.eval()



    def refinement_gradVConturMulti(self, arr_imges):

        NumImges = 6
        # the first image in arr_imges is the gradient map
        grad_img = torch.squeeze(arr_imges[0].detach()).cpu().numpy()
        # print(np.array(255 * grad_img).astype(int))


        # print(np.max(grad_img))
        if SHOW_RESULTS:
            plt.figure()
            plt.axis('off')
            plt.imshow(grad_img*255)
            plt.savefig("grad_img_th.png")
            plt.show()


        ## LSD
        # print((grad_img*255).astype('uint8'))

        grad_img_lsd = np.zeros(grad_img.shape)

        # arr_imges is an array of gray images, or a single gray image of shape (1, H, W)
        # print("Start: get LSD for the RCF results")

        # create an LSD detector with the default parametrization
        lsd = cv2.createLineSegmentDetector(0)
        kernel2 = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=np.float32)
        kernel3 = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=np.float32)
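        # kernel2/kernel3 are Sobel kernels; they are defined here but never
        # used below (apparently left over from an earlier experiment)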

        img_rcf_LSD = (grad_img*255).astype('uint8')
        # img_CannySobel_LSD = result.astype(np.uint8).copy()
        lines_img_rcf_LSD = lsd.detect(img_rcf_LSD)[0]  # position 0 of the returned tuple holds the detected lines
        # print(lines_img_rcf_LSD)
        segment_img_rcf_LSD = np.zeros_like(img_rcf_LSD)
        # lines_img_CannySobel_LSD = lsd.detect( img_CannySobel_LSD)[0]
        # segment_img_CannySobel_LSD = np.zeros_like( img_CannySobel_LSD)

        # each detected segment comes as [[x0, y0, x1, y1]]; rasterize it into the mask
        for dline in lines_img_rcf_LSD:
            x0 = int(round(dline[0][0]))
            y0 = int(round(dline[0][1]))
            x1 = int(round(dline[0][2]))
            y1 = int(round(dline[0][3]))
            cv2.line(grad_img_lsd, (x0, y0), (x1, y1), 255, 1, cv2.LINE_AA)

        if SHOW_RESULTS:
            plt.figure()
            plt.axis('off')
            plt.imshow(grad_img_lsd)
            plt.savefig("grad_img_lsd.png")
            plt.show()
            print(grad_img_lsd)

            plt.figure()
            plt.axis('off')
            plt.imshow(grad_img_lsd*(grad_img))
            plt.savefig("grad_img_lsd_(grad_img).png")
            plt.show()


        # cv2.imshow("SImag",grad_img_lsd)
        # cv2.waitKey(0)
        # refine_grad = (grad_img*255).astype('uint8')
        # refine_lsd = grad_img_lsd.copy()
        # refine_grad_and_lsd = grad_img_lsd*(grad_img).copy()
        # im_res = np.zeros(grad_img.shape)
        # alphs = [0.6,0.1,0.1,0.1,0.1,0.1]
        # im_res = im_res + alphs[0]*refine_grad

        # for i in range(1, 6):
        #     img = torch.squeeze(arr_imges[i].detach()).cpu().numpy()
        #     # print(alphs[i])
        #     #
        #     # img = (img * 255).astype('uint8')
        #     # img_lsd = np.zeros(img.shape)
        #     # lines_img_rcf_LSD = lsd.detect(img)[0]
        #     # for dline in lines_img_rcf_LSD:
        #     #     x0 = int(round(dline[0][0]))
        #     #     y0 = int(round(dline[0][1]))
        #     #     x1 = int(round(dline[0][2]))
        #     #     y1 = int(round(dline[0][3]))
        #     #     cv2.line(img_lsd, (x0, y0), (x1, y1), 255, 1, cv2.LINE_AA)
        #     #
        #     #
        #     # cv2.imshow("SImag",img_lsd)
        #     # cv2.waitKey(0)
        #
        #     im_res =im_res + np.multiply(255,alphs[i]*img).astype(int)
        #     # im_res = (255 * (im_res - np.min(im_res)) / (np.max(im_res) - np.min(im_res))).astype('uint8')
        #     print(im_res)
        #     # print(np.max(im_res))
        #     # Scaling done : divide by max and multiply with 255
        #
        #     # im_res_scl = (255 * (im_res-np.min(im_res))/(np.max(im_res)- np.min(im_res))).astype('uint8')
        #
        #     if SHOW_RESULTS:
        #         plt.figure()
        #         plt.imshow(im_res)
        #         plt.figure()
        #
        #
        #
        #
        # im_res_lsd = np.zeros(grad_img.shape).astype('uint8')
        # lines_img_rcf_LSD = lsd.detect(im_res.astype('uint8'))[0]
        #
        # for dline in lines_img_rcf_LSD:
        #     x0 = int(round(dline[0][0]))
        #     y0 = int(round(dline[0][1]))
        #     x1 = int(round(dline[0][2]))
        #     y1 = int(round(dline[0][3]))
        #     cv2.line(im_res_lsd, (x0, y0), (x1, y1), 255, 1, cv2.LINE_AA)
        #
        #



        # correlation-refined image
        alphas = np.array([0.3,0,0,0,0,0.7])
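        # weighted fusion of the six RCF outputs: index 0 is the gradient map
        # (weight 0.3), index 5 the fused output (weight 0.7); the intermediate
        # side outputs get weight 0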

        im_res = 255 * alphas[0] * grad_img
        # print(im_res)
        for i in range(1, 6):
            img = torch.squeeze(arr_imges[i].detach()).cpu().numpy()
            # im_res = (255 * (im_res - np.min(im_res)) / (np.max(im_res) - np.min(im_res))).astype('uint8')
            im_res = 255 * img * alphas[i] + im_res

            # print(im_res)


            # print(np.max(im_res))
            # Scaling done : divide by max and multiply with 255

            if SHOW_RESULTS:
                plt.figure()
                plt.imshow(im_res)
                plt.show()



       


        return im_res, grad_img_lsd




    def boundary_detection(self, image):

        # build a one-image test loader
        if not isdir(os.path.join(script_dir, "test")):
            os.makedirs(os.path.join(script_dir, "test"))
        width_p = 481
        height_p = 321
        w_or, h_or = image.shape[0:2]  # note: shape is (height, width), so these names are swapped (unused below)

        image = cv2.resize(image, (width_p, height_p))
        cv2.imwrite(os.path.join(script_dir, "test", "tmp.jpg"), image)
        fil = open(os.path.join(script_dir, "test.lst"), "a+")
        fil.write('test/tmp.jpg\n')
        fil.close()  # flush test.lst before the loader reads it

        test_dataset = BSDS_RCFLoader(root=os.path.join(script_dir), split="test")
        # print(test_dataset.filelist)
        test_loader = DataLoader(
            test_dataset, batch_size=1,
            num_workers=8, drop_last=True, shuffle=False)


        for idx, image in enumerate(test_loader):
            image = image.cuda()
            _, _, H, W = image.shape
            results = self.model(image)
            refine_img, lsd_grd = self.refinement_gradVConturMulti(results)
            res_fusion = 255 * torch.squeeze(results[-1].detach()).cpu().numpy()
            gradient = 255 * torch.squeeze(results[0].detach()).cpu().numpy()
            results_all = torch.zeros((len(results), 1, H, W))
            if SHOW_RESULTS:
                plt.imshow(res_fusion)
                plt.show()

            # cv2.imshow('Fusion', res_fusion.astype('uint8'))
            # cv2.waitKey(0)
            # cv2.imshow('Gradient', torch.squeeze(results[0].detach()).cpu().numpy())
            # cv2.waitKey(0)
            # cv2.imshow('refine_img', refine_img.astype('uint8'))
            # cv2.waitKey(0)
            # cv2.imshow('lsd_grd', lsd_grd)
            # cv2.waitKey(0)


        # clear test.lst so the temporary entry is not appended again on the next call
        open(os.path.join(script_dir, "test.lst"), 'w').close()

        return res_fusion, gradient, refine_img, lsd_grd
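
A minimal usage sketch for the detector class above (the driver code and file names here are assumptions, not part of the original example):

import cv2

# instantiate the detector (loads the VGG16 weights and the epoch-12 checkpoint)
detector = RCF_BoundaryOcclusionBoundaryDetector()
img = cv2.imread('sample.jpg')  # any BGR image; it is resized to 481x321 internally
res_fusion, gradient, refine_img, lsd_grad = detector.boundary_detection(img)
cv2.imwrite('fused_edges.png', res_fusion.astype('uint8'))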