Example #1
def __init__(self):
    # model: RCF, weights_init, and load_vgg16pretrain come from the
    # surrounding project; script_dir is the directory of this script
    self.model = RCF()
    self.model.cuda()
    self.vgg_model_name = os.path.join(script_dir, "fast-rcnn-vgg16-pascal07-dagnn.mat")
    self.weights_pretrained = os.path.join(script_dir, "RCFcheckpoint_epoch12.pth")
    self.model.apply(weights_init)
    load_vgg16pretrain(self.model, self.vgg_model_name)
    self.checkpoint = torch.load(self.weights_pretrained)
    self.model.load_state_dict(self.checkpoint['state_dict'])
    self.model.eval()
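
A minimal sketch of how an initializer like this might be used for inference. The wrapper class name, input shape, and preprocessing are assumptions for illustration, not part of the original code:

import torch

# Hypothetical wrapper owning the __init__ from Example #1.
detector = EdgeDetectorWrapper()  # assumed class name
# RCF expects an NCHW float tensor; real inputs are mean-subtracted BGR images.
image = torch.randn(1, 3, 321, 481).cuda()
with torch.no_grad():
    outputs = detector.model(image)      # side outputs plus a fused edge map
edge_map = outputs[-1].squeeze().cpu()   # in the reference RCF code, the last element is the fused map
print(edge_map.shape)
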
Example #2
def main():
    args.cuda = True
    # dataset
    train_dataset = BSDSLoader(root=args.dataset, split="train")
    test_dataset = BSDSLoader(root=args.dataset, split="test")
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=8,
                              drop_last=True,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             drop_last=True,
                             shuffle=False)
    with open('data/HED-BSDS/test.lst', 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]  # split here is os.path.split: keep only the filename
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list),
                                                             len(test_loader))
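    # NOTE: len(test_loader) counts batches, not images, so this assert only
    # holds when batch_size == 1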

    # model
    model = HED()
    model.cuda()
    model.apply(weights_init)
    load_vgg16pretrain(model)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # group parameters for per-layer learning rates and weight decay
    net_parameters_id = {}
    net = model
    for pname, p in net.named_parameters():
        if pname in [
                'conv1_1.weight', 'conv1_2.weight', 'conv2_1.weight',
                'conv2_2.weight', 'conv3_1.weight', 'conv3_2.weight',
                'conv3_3.weight', 'conv4_1.weight', 'conv4_2.weight',
                'conv4_3.weight'
        ]:
            #print(pname, 'lr:1 de:1')
            if 'conv1-4.weight' not in net_parameters_id:
                net_parameters_id['conv1-4.weight'] = []
            net_parameters_id['conv1-4.weight'].append(p)
        elif pname in [
                'conv1_1.bias', 'conv1_2.bias', 'conv2_1.bias', 'conv2_2.bias',
                'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias', 'conv4_1.bias',
                'conv4_2.bias', 'conv4_3.bias'
        ]:
            #print(pname, 'lr:2 de:0')
            if 'conv1-4.bias' not in net_parameters_id:
                net_parameters_id['conv1-4.bias'] = []
            net_parameters_id['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight']:
            #print(pname, 'lr:100 de:1')
            if 'conv5.weight' not in net_parameters_id:
                net_parameters_id['conv5.weight'] = []
            net_parameters_id['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias']:
            #print(pname, 'lr:200 de:0')
            if 'conv5.bias' not in net_parameters_id:
                net_parameters_id['conv5.bias'] = []
            net_parameters_id['conv5.bias'].append(p)

        elif pname in [
                'score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
                'score_dsn4.weight', 'score_dsn5.weight'
        ]:
            #print(pname, 'lr:0.01 de:1')
            if 'score_dsn_1-5.weight' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.weight'] = []
            net_parameters_id['score_dsn_1-5.weight'].append(p)
        elif pname in [
                'score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
                'score_dsn4.bias', 'score_dsn5.bias'
        ]:
            #print(pname, 'lr:0.02 de:0')
            if 'score_dsn_1-5.bias' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.bias'] = []
            net_parameters_id['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_final.weight']:
            #print(pname, 'lr:0.001 de:1')
            if 'score_final.weight' not in net_parameters_id:
                net_parameters_id['score_final.weight'] = []
            net_parameters_id['score_final.weight'].append(p)
        elif pname in ['score_final.bias']:
            #print(pname, 'lr:0.002 de:0')
            if 'score_final.bias' not in net_parameters_id:
                net_parameters_id['score_final.bias'] = []
            net_parameters_id['score_final.bias'].append(p)

    optimizer = torch.optim.SGD([
        {
            'params': net_parameters_id['conv1-4.weight'],
            'lr': args.lr * 1,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv1-4.bias'],
            'lr': args.lr * 2,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['conv5.weight'],
            'lr': args.lr * 100,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv5.bias'],
            'lr': args.lr * 200,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['score_dsn_1-5.weight'],
            'lr': args.lr * 0.01,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['score_dsn_1-5.bias'],
            'lr': args.lr * 0.02,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['score_final.weight'],
            'lr': args.lr * 0.001,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['score_final.bias'],
            'lr': args.lr * 0.002,
            'weight_decay': 0.
        },
    ],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.stepsize,
                                    gamma=args.gamma)
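    # StepLR multiplies the lr of every param group by gamma every `stepsize`
    # epochs, so the per-layer lr ratios set above are preserved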

    # log
    log = Logger(join(TMP_DIR, '%s-%g-log.txt' % ('sgd', args.lr)))  # 'sgd' matches the optimizer; %g keeps fractional lr values
    sys.stdout = log

    train_loss = []
    train_loss_detail = []
    for epoch in range(args.start_epoch, args.maxepoch):
        # if epoch == 0:
        #     print("Performing initial testing...")
        #     test(model, test_loader, epoch=epoch, test_list=test_list,
        #          save_dir=join(TMP_DIR, 'initial-testing-record'))

        tr_avg_loss, tr_detail_loss = train(
            train_loader,
            model,
            optimizer,
            epoch,
            save_dir=join(TMP_DIR, 'epoch-%d-training-record' % epoch))
        test(model,
             test_loader,
             epoch=epoch,
             test_list=test_list,
             save_dir=join(TMP_DIR, 'epoch-%d-testing-record' % epoch))
        log.flush()  # write log
        # Save checkpoint
        save_file = os.path.join(TMP_DIR,
                                 'checkpoint_epoch{}.pth'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            filename=save_file)
        scheduler.step()  # will adjust learning rate
        # save train/val loss/accuracy, save every epoch in case of early stop
        train_loss.append(tr_avg_loss)
        train_loss_detail += tr_detail_loss
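
The ten near-identical branches above reduce to a lookup from parameter-name prefix to an (lr multiplier, weight-decay flag) pair. A minimal sketch of the same grouping as a helper; the function name is made up, and the multipliers mirror Example #2's HED layer names:

import torch

def build_param_groups(model, base_lr, weight_decay):
    """Group parameters by name prefix with HED-style per-group lr multipliers."""
    # (lr multiplier, weight-decay multiplier) per logical group
    rules = {
        'conv1-4.weight': (1, 1),         'conv1-4.bias': (2, 0),
        'conv5.weight': (100, 1),         'conv5.bias': (200, 0),
        'score_dsn.weight': (0.01, 1),    'score_dsn.bias': (0.02, 0),
        'score_final.weight': (0.001, 1), 'score_final.bias': (0.002, 0),
    }
    buckets = {key: [] for key in rules}

    def key_for(pname):
        layer, kind = pname.rsplit('.', 1)  # e.g. 'conv3_2', 'weight'
        if layer.startswith('conv5'):
            return 'conv5.' + kind
        if layer.startswith('conv'):
            return 'conv1-4.' + kind
        if layer.startswith('score_dsn'):
            return 'score_dsn.' + kind
        if layer == 'score_final':
            return 'score_final.' + kind
        return None  # leave unknown parameters out, as the original loop does

    for pname, p in model.named_parameters():
        key = key_for(pname)
        if key is not None:
            buckets[key].append(p)

    return [{'params': buckets[key], 'lr': base_lr * mult,
             'weight_decay': weight_decay * decay}
            for key, (mult, decay) in rules.items()]

# optimizer = torch.optim.SGD(build_param_groups(model, args.lr, args.weight_decay),
#                             lr=args.lr, momentum=args.momentum)
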
Example #3
def main():
    args.cuda = True
    # dataset
    train_dataset = BSDS_RCFLoader(root=args.dataset, split="train")
    test_dataset = BSDS_RCFLoader(root=args.dataset + "/HED-BSDS",
                                  split="test")
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=8,
                              drop_last=True,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             drop_last=True,
                             shuffle=False)

    # model
    model = RCF()
    model.cuda()
    model.apply(weights_init)
    load_vgg16pretrain(model)
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            raise Exception("=> no checkpoint found at '{}'".format(args.resume))
    else:
        raise Exception("args.resume is required: this variant does not train from scratch")

    # group parameters for per-layer learning rates and weight decay
    net_parameters_id = {}
    net = model
    for pname, p in net.named_parameters():
        if pname in [
                'conv1_1.weight', 'conv1_2.weight', 'conv2_1.weight',
                'conv2_2.weight', 'conv3_1.weight', 'conv3_2.weight',
                'conv3_3.weight', 'conv4_1.weight', 'conv4_2.weight',
                'conv4_3.weight'
        ]:
            print(pname, 'lr:1 de:1')
            if 'conv1-4.weight' not in net_parameters_id:
                net_parameters_id['conv1-4.weight'] = []
            net_parameters_id['conv1-4.weight'].append(p)
        elif pname in [
                'conv1_1.bias', 'conv1_2.bias', 'conv2_1.bias', 'conv2_2.bias',
                'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias', 'conv4_1.bias',
                'conv4_2.bias', 'conv4_3.bias'
        ]:
            print(pname, 'lr:2 de:0')
            if 'conv1-4.bias' not in net_parameters_id:
                net_parameters_id['conv1-4.bias'] = []
            net_parameters_id['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight']:
            print(pname, 'lr:100 de:1')
            if 'conv5.weight' not in net_parameters_id:
                net_parameters_id['conv5.weight'] = []
            net_parameters_id['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias']:
            print(pname, 'lr:200 de:0')
            if 'conv5.bias' not in net_parameters_id:
                net_parameters_id['conv5.bias'] = []
            net_parameters_id['conv5.bias'].append(p)
        elif pname in [
                'conv1_1_down.weight', 'conv1_2_down.weight',
                'conv2_1_down.weight', 'conv2_2_down.weight',
                'conv3_1_down.weight', 'conv3_2_down.weight',
                'conv3_3_down.weight', 'conv4_1_down.weight',
                'conv4_2_down.weight', 'conv4_3_down.weight',
                'conv5_1_down.weight', 'conv5_2_down.weight',
                'conv5_3_down.weight'
        ]:
            print(pname, 'lr:0.1 de:1')
            if 'conv_down_1-5.weight' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.weight'] = []
            net_parameters_id['conv_down_1-5.weight'].append(p)
        elif pname in [
                'conv1_1_down.bias', 'conv1_2_down.bias', 'conv2_1_down.bias',
                'conv2_2_down.bias', 'conv3_1_down.bias', 'conv3_2_down.bias',
                'conv3_3_down.bias', 'conv4_1_down.bias', 'conv4_2_down.bias',
                'conv4_3_down.bias', 'conv5_1_down.bias', 'conv5_2_down.bias',
                'conv5_3_down.bias'
        ]:
            print(pname, 'lr:0.2 de:0')
            if 'conv_down_1-5.bias' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.bias'] = []
            net_parameters_id['conv_down_1-5.bias'].append(p)
        elif pname in [
                'score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
                'score_dsn4.weight', 'score_dsn5.weight'
        ]:
            print(pname, 'lr:0.01 de:1')
            if 'score_dsn_1-5.weight' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.weight'] = []
            net_parameters_id['score_dsn_1-5.weight'].append(p)
        elif pname in [
                'score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
                'score_dsn4.bias', 'score_dsn5.bias'
        ]:
            print(pname, 'lr:0.02 de:0')
            if 'score_dsn_1-5.bias' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.bias'] = []
            net_parameters_id['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_final.weight']:
            print(pname, 'lr:0.001 de:1')
            if 'score_final.weight' not in net_parameters_id:
                net_parameters_id['score_final.weight'] = []
            net_parameters_id['score_final.weight'].append(p)
        elif pname in ['score_final.bias']:
            print(pname, 'lr:0.002 de:0')
            if 'score_final.bias' not in net_parameters_id:
                net_parameters_id['score_final.bias'] = []
            net_parameters_id['score_final.bias'].append(p)

    optimizer = torch.optim.SGD([
        {
            'params': net_parameters_id['conv1-4.weight'],
            'lr': args.lr * 1,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv1-4.bias'],
            'lr': args.lr * 2,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['conv5.weight'],
            'lr': args.lr * 100,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv5.bias'],
            'lr': args.lr * 200,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['conv_down_1-5.weight'],
            'lr': args.lr * 0.1,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['conv_down_1-5.bias'],
            'lr': args.lr * 0.2,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['score_dsn_1-5.weight'],
            'lr': args.lr * 0.01,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['score_dsn_1-5.bias'],
            'lr': args.lr * 0.02,
            'weight_decay': 0.
        },
        {
            'params': net_parameters_id['score_final.weight'],
            'lr': args.lr * 0.001,
            'weight_decay': args.weight_decay
        },
        {
            'params': net_parameters_id['score_final.bias'],
            'lr': args.lr * 0.002,
            'weight_decay': 0.
        },
    ],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.stepsize,
                                    gamma=args.gamma)

    # log
    log = Logger(join(TMP_DIR, '%s-%g-log.txt' % ('sgd', args.lr)))  # %g keeps fractional lr values
    sys.stdout = log

    for epoch in range(args.start_epoch, args.maxepoch):

        tr_avg_loss, tr_detail_loss = train(
            train_loader,
            model,
            optimizer,
            epoch,
            save_dir=join(TMP_DIR, 'epoch-%d-training-record' % epoch))

        # with torch.no_grad():
        #     test(model, test_loader, epoch=epoch,
        #          save_dir=join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch))
        #     multiscale_test(model, test_loader, epoch=epoch,
        #                     save_dir=join(TMP_DIR, 'epoch-%d-testing-record' % epoch))

        log.flush()  # write log

        # Save checkpoint
        save_file = os.path.join(TMP_DIR, 'checkpoint.pth')
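        # NOTE: unlike Example #2, the same checkpoint.pth is overwritten every
        # epoch, so only the most recent weights survive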
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            filename=save_file)

        scheduler.step()  # will adjust learning rate
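
The resume branch above only restores model weights, so an interrupted run would restart its optimizer state and epoch counter from scratch. A minimal sketch of a fuller resume, assuming the checkpoint dict format saved above (the helper name is illustrative):

import torch

def resume_from(path, model, optimizer):
    """Restore model weights, optimizer state, and the next epoch index."""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # the checkpoint is written at the end of `epoch`, so resume after it
    return checkpoint['epoch'] + 1

# start_epoch = resume_from(os.path.join(TMP_DIR, 'checkpoint.pth'), model, optimizer)
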
Example #4
def main():
    args.cuda = True
    # dataset
    train_dataset = BSDS_RCFLoader(root=args.dataset, split="train")
    test_dataset = BSDS_RCFLoader(root=args.dataset, split="test")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size,
        num_workers=8, drop_last=True, shuffle=True)
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=8, drop_last=True, shuffle=False)


    with open('data/HED-BSDS_PASCAL/test.lst', 'r') as f:
    # with open('data/HED-BSDS/test.lst', 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list), len(test_loader))

    # model
    model = RCF()
    model.cuda()
    model.apply(weights_init)
    load_vgg16pretrain(model, vgg_model_name)  # vgg_model_name is defined elsewhere (module level)

    if args.resume:
        if isfile(args.resume): 
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'"
                  .format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    
    # group parameters for per-layer learning rates and weight decay
    net_parameters_id = {}
    net = model
    for pname, p in net.named_parameters():
        if pname in ['conv1_1.weight','conv1_2.weight',
                     'conv2_1.weight','conv2_2.weight',
                     'conv3_1.weight','conv3_2.weight','conv3_3.weight',
                     'conv4_1.weight','conv4_2.weight','conv4_3.weight']:
            print(pname, 'lr:1 de:1')
            if 'conv1-4.weight' not in net_parameters_id:
                net_parameters_id['conv1-4.weight'] = []
            net_parameters_id['conv1-4.weight'].append(p)
        elif pname in ['conv1_1.bias','conv1_2.bias',
                       'conv2_1.bias','conv2_2.bias',
                       'conv3_1.bias','conv3_2.bias','conv3_3.bias',
                       'conv4_1.bias','conv4_2.bias','conv4_3.bias']:
            print(pname, 'lr:2 de:0')
            if 'conv1-4.bias' not in net_parameters_id:
                net_parameters_id['conv1-4.bias'] = []
            net_parameters_id['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight','conv5_2.weight','conv5_3.weight']:
            print(pname, 'lr:100 de:1')
            if 'conv5.weight' not in net_parameters_id:
                net_parameters_id['conv5.weight'] = []
            net_parameters_id['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias','conv5_2.bias','conv5_3.bias'] :
            print(pname, 'lr:200 de:0')
            if 'conv5.bias' not in net_parameters_id:
                net_parameters_id['conv5.bias'] = []
            net_parameters_id['conv5.bias'].append(p)
        elif pname in ['conv1_1_down.weight','conv1_2_down.weight',
                       'conv2_1_down.weight','conv2_2_down.weight',
                       'conv3_1_down.weight','conv3_2_down.weight','conv3_3_down.weight',
                       'conv4_1_down.weight','conv4_2_down.weight','conv4_3_down.weight',
                       'conv5_1_down.weight','conv5_2_down.weight','conv5_3_down.weight']:
            print(pname, 'lr:0.1 de:1')
            if 'conv_down_1-5.weight' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.weight'] = []
            net_parameters_id['conv_down_1-5.weight'].append(p)
        elif pname in ['conv1_1_down.bias','conv1_2_down.bias',
                       'conv2_1_down.bias','conv2_2_down.bias',
                       'conv3_1_down.bias','conv3_2_down.bias','conv3_3_down.bias',
                       'conv4_1_down.bias','conv4_2_down.bias','conv4_3_down.bias',
                       'conv5_1_down.bias','conv5_2_down.bias','conv5_3_down.bias']:
            print(pname, 'lr:0.2 de:0')
            if 'conv_down_1-5.bias' not in net_parameters_id:
                net_parameters_id['conv_down_1-5.bias'] = []
            net_parameters_id['conv_down_1-5.bias'].append(p)
        elif pname in ['score_dsn1.weight','score_dsn2.weight','score_dsn3.weight',
                       'score_dsn4.weight','score_dsn5.weight']:
            print(pname, 'lr:0.01 de:1')
            if 'score_dsn_1-5.weight' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.weight'] = []
            net_parameters_id['score_dsn_1-5.weight'].append(p)
        elif pname in ['score_dsn1.bias','score_dsn2.bias','score_dsn3.bias',
                       'score_dsn4.bias','score_dsn5.bias']:
            print(pname, 'lr:0.02 de:0')
            if 'score_dsn_1-5.bias' not in net_parameters_id:
                net_parameters_id['score_dsn_1-5.bias'] = []
            net_parameters_id['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_final.weight']:
            print(pname, 'lr:0.001 de:1')
            if 'score_final.weight' not in net_parameters_id:
                net_parameters_id['score_final.weight'] = []
            net_parameters_id['score_final.weight'].append(p)
        elif pname in ['score_final.bias']:
            print(pname, 'lr:0.002 de:0')
            if 'score_final.bias' not in net_parameters_id:
                net_parameters_id['score_final.bias'] = []
            net_parameters_id['score_final.bias'].append(p)

    optimizer = torch.optim.SGD([
            {'params': net_parameters_id['conv1-4.weight']      , 'lr': args.lr*1    , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['conv1-4.bias']        , 'lr': args.lr*2    , 'weight_decay': 0.},
            {'params': net_parameters_id['conv5.weight']        , 'lr': args.lr*100  , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['conv5.bias']          , 'lr': args.lr*200  , 'weight_decay': 0.},
            {'params': net_parameters_id['conv_down_1-5.weight'], 'lr': args.lr*0.1  , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['conv_down_1-5.bias']  , 'lr': args.lr*0.2  , 'weight_decay': 0.},
            {'params': net_parameters_id['score_dsn_1-5.weight'], 'lr': args.lr*0.01 , 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['score_dsn_1-5.bias']  , 'lr': args.lr*0.02 , 'weight_decay': 0.},
            {'params': net_parameters_id['score_final.weight']  , 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
            {'params': net_parameters_id['score_final.bias']    , 'lr': args.lr*0.002, 'weight_decay': 0.},
        ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)



    # optimizer = torch.optim.Adam([
    #         {'params': net_parameters_id['conv1-4.weight']      , 'lr': args.lr*1    , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv1-4.bias']        , 'lr': args.lr*2    , 'weight_decay': 0.},
    #         {'params': net_parameters_id['conv5.weight']        , 'lr': args.lr*100  , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv5.bias']          , 'lr': args.lr*200  , 'weight_decay': 0.},
    #         {'params': net_parameters_id['conv_down_1-5.weight'], 'lr': args.lr*0.1  , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['conv_down_1-5.bias']  , 'lr': args.lr*0.2  , 'weight_decay': 0.},
    #         {'params': net_parameters_id['score_dsn_1-5.weight'], 'lr': args.lr*0.01 , 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['score_dsn_1-5.bias']  , 'lr': args.lr*0.02 , 'weight_decay': 0.},
    #         {'params': net_parameters_id['score_final.weight']  , 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
    #         {'params': net_parameters_id['score_final.bias']    , 'lr': args.lr*0.002, 'weight_decay': 0.},
    #     ], lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    # log
    log = Logger(join(TMP_DIR, '%s-%g-log.txt' % ('sgd', args.lr)))  # %g keeps fractional lr values
    sys.stdout = log

    train_loss = []
    train_loss_detail = []

    print("The start is epoch", args.start_epoch)
    print("The max is epoch",args.maxepoch)
    # for epoch in range(args.start_epoch, args.maxepoch):
    #     if epoch == 0:
    #         print("Performing initial testing...")
    #         multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #              save_dir = join(TMP_DIR, 'initial-testing-record'))
    #
    #     tr_avg_loss, tr_detail_loss = train(
    #         train_loader, model, optimizer, epoch,
    #         save_dir = join(TMP_DIR, 'epoch-%d-training-record' % epoch))
    #
    #     test(model, test_loader, epoch=epoch, test_list=test_list,
    #         save_dir = join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch))
    #
    #     multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #         save_dir = join(TMP_DIR, 'epoch-%d-testing-record' % epoch))
    #
    #     log.flush() # write log
    #     # Save checkpoint
    #     save_file = os.path.join(TMP_DIR, 'checkpoint_epoch{}.pth'.format(epoch))
    #     save_checkpoint({
    #         'epoch': epoch,
    #         'state_dict': model.state_dict(),
    #         'optimizer': optimizer.state_dict()
    #                      }, filename=save_file)
    #     scheduler.step() # will adjust learning rate
    #     # save train/val loss/accuracy, save every epoch in case of early stop
    #     train_loss.append(tr_avg_loss)
    #     train_loss_detail += tr_detail_loss


    # Test the pretrained model over the test images
    checkpoint = torch.load("RCFcheckpoint_epoch12.pth")
    model.load_state_dict(checkpoint['state_dict'])
    epoch = 0
    # print("Performing initial testing...")
    # multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #              save_dir = join(TMP_DIR, 'initial-testing-record'))
    # test(model, test_loader, epoch=epoch, test_list=test_list,
    #         save_dir = join(TMP_DIR, 'epoch-%d-testing-record-view' % epoch))
    #

    ########################
    # Test on our own dataset
    ########################
    vid_name = 'StudySpot'
    # # Read the video from specified path
    # cam = cv2.VideoCapture('data/Book_Occ3.MOV')
    # try:
    #     # creating a folder named data
    #     if not os.path.exists('data/'+vid_name+'/test'):
    #         os.makedirs('data/'+vid_name+'/test')
    #         # if not created then raise error
    # except OSError:
    #     print('Error: Creating directory of data')
    #     # frame
    # currentframe = 0
    # fil = open('data/'+vid_name+'/test.lst', "a+")
    # width_p = 481
    # height_p = 321
    #
    # while (True):
    #     # reading from frame
    #     ret, frame = cam.read()
    #     if ret:
    #         # if video is still left continue creating images
    #         name = 'data/' +vid_name+'/test/'+ str(currentframe) + '.jpg'
    #         print('Creating...' + name)
    #         # writing the extracted images
    #         frame = cv2.resize(frame, (width_p, height_p))
    #         cv2.imwrite(name, frame)
    #         fil.write('test/'+ str(currentframe) + '.jpg\n')
    #         # increasing counter so that it will
    #         # show how many frames are created
    #         currentframe += 1
    #
    #     else:
    #         break
    #
    # # Release all space and windows once done
    # cam.release()
    # fil.close()
    # cv2.destroyAllWindows()

    test_dataset = BSDS_RCFLoader(root='data/'+vid_name, split="test")
    print(test_dataset.filelist)

    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=8, drop_last=True, shuffle=False)

    with open('data/'+vid_name+'/test.lst', 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list), len(test_loader))


    epoch = 0
    print("Performing testing...")
    # multiscale_test(model, test_loader, epoch=epoch, test_list=test_list,
    #              save_dir = join(TMP_DIR,vid_name,'initial-testing-record'))
    test(model, test_loader, epoch=epoch, test_list=test_list,
         save_dir=join(TMP_DIR, vid_name, 'epoch-%d-testing-record-view' % epoch))
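
The commented-out block above extracts video frames with OpenCV, resizes them to 481x321, and writes a test.lst index. A tidied sketch of the same idea as a reusable function; the names and frame size follow the commented code, the rest is illustrative:

import os
import cv2

def extract_frames(video_path, out_root, size=(481, 321)):
    """Dump resized frames of a video into <out_root>/test and index them in test.lst."""
    frame_dir = os.path.join(out_root, 'test')
    os.makedirs(frame_dir, exist_ok=True)
    cam = cv2.VideoCapture(video_path)
    count = 0
    with open(os.path.join(out_root, 'test.lst'), 'w') as lst:
        while True:
            ret, frame = cam.read()
            if not ret:
                break  # end of video
            frame = cv2.resize(frame, size)
            cv2.imwrite(os.path.join(frame_dir, '%d.jpg' % count), frame)
            lst.write('test/%d.jpg\n' % count)
            count += 1
    cam.release()
    return count

# extract_frames('data/Book_Occ3.MOV', 'data/StudySpot')
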
Example #5
def main():
    args.cuda = True
    # dataset
    train_dataset = BSDSLoader(root=args.dataset, split="train")
    test_dataset = BSDSLoader(root=args.dataset, split="test")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size,
        num_workers=4, drop_last=True, shuffle=True)
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=4, drop_last=True, shuffle=False)
    with open(join(args.dataset, 'test.lst'), 'r') as f:
        test_list = f.readlines()
    test_list = [split(i.rstrip())[1] for i in test_list]
    assert len(test_list) == len(test_loader), "%d vs %d" % (len(test_list), len(test_loader))

    # default hyperparameters
    if args.use_cfg:
        if args.pretrained and not args.small:
            args.stepsize = 2
            args.lr = 0.001 if args.harmonic else 0.0002
        elif args.small:
            args.stepsize = 6
            args.lr = 0.005 if args.harmonic else 0.001
        else:
            args.stepsize = 4
            args.lr = 0.0005 if args.harmonic else 0.0002
        args.maxepoch = args.stepsize + 1
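        # i.e. train `stepsize` epochs at the base lr, then one final epoch
        # after the single StepLR decay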
        
    # model
    model = HEDSmall(harmonic=args.harmonic) if args.small else HED(harmonic=args.harmonic)
    model.cuda()
    model.apply(weights_init)
    if args.pretrained and not args.small:
        if args.harmonic:    
            load_harm_vgg16pretrain(model)
        else:
            load_vgg16pretrain(model)
    
    # group parameters for per-layer learning rates and weight decay
    net_parameters_id = {}
    
    if args.pretrained and not args.small:
        for pname, p in model.named_parameters():
            if pname in ['conv1_1.weight','conv1_2.weight',
                         'conv2_1.weight','conv2_2.weight',
                         'conv3_1.weight','conv3_2.weight','conv3_3.weight',
                         'conv4_1.weight','conv4_2.weight','conv4_3.weight',
                         'conv5_1.weight','conv5_2.weight','conv5_3.weight']:
                print(pname, 'lr:1 de:1')
                if 'conv1-5.weight' not in net_parameters_id:
                    net_parameters_id['conv1-5.weight'] = []
                net_parameters_id['conv1-5.weight'].append(p)
            elif pname in ['conv1_1.bias','conv1_2.bias',
                           'conv2_1.bias','conv2_2.bias',
                           'conv3_1.bias','conv3_2.bias','conv3_3.bias',
                           'conv4_1.bias','conv4_2.bias','conv4_3.bias',
                           'conv5_1.bias','conv5_2.bias','conv5_3.bias']:
                print(pname, 'lr:2 de:0')
                if 'conv1-5.bias' not in net_parameters_id:
                    net_parameters_id['conv1-5.bias'] = []
                net_parameters_id['conv1-5.bias'].append(p)     
            elif pname in ['score_dsn1.weight','score_dsn2.weight','score_dsn3.weight',
                           'score_dsn4.weight','score_dsn5.weight']:
                print(pname, 'lr:0.01 de:1')
                if 'score_dsn_1-5.weight' not in net_parameters_id:
                    net_parameters_id['score_dsn_1-5.weight'] = []
                net_parameters_id['score_dsn_1-5.weight'].append(p)
            elif pname in ['score_dsn1.bias','score_dsn2.bias','score_dsn3.bias',
                           'score_dsn4.bias','score_dsn5.bias']:
                print(pname, 'lr:0.02 de:0')
                if 'score_dsn_1-5.bias' not in net_parameters_id:
                    net_parameters_id['score_dsn_1-5.bias'] = []
                net_parameters_id['score_dsn_1-5.bias'].append(p)
            elif pname in ['score_final.weight']:
                print(pname, 'lr:0.001 de:1')
                if 'score_final.weight' not in net_parameters_id:
                    net_parameters_id['score_final.weight'] = []
                net_parameters_id['score_final.weight'].append(p)
            elif pname in ['score_final.bias']:
                print(pname, 'lr:0.002 de:0')
                if 'score_final.bias' not in net_parameters_id:
                    net_parameters_id['score_final.bias'] = []
                net_parameters_id['score_final.bias'].append(p)
        param_groups = [
                {'params': net_parameters_id['conv1-5.weight']      , 'lr': args.lr*1    , 'weight_decay': args.weight_decay},
                {'params': net_parameters_id['conv1-5.bias']        , 'lr': args.lr*2    , 'weight_decay': 0.},
                {'params': net_parameters_id['score_dsn_1-5.weight'], 'lr': args.lr*0.01 , 'weight_decay': args.weight_decay},
                {'params': net_parameters_id['score_dsn_1-5.bias']  , 'lr': args.lr*0.02 , 'weight_decay': 0.},
                {'params': net_parameters_id['score_final.weight']  , 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
                {'params': net_parameters_id['score_final.bias']    , 'lr': args.lr*0.002, 'weight_decay': 0.}
            ]
    else:
        net_parameters_id = {'weights': [], 'biases': []}
        for pname, p in model.named_parameters():
            if 'weight' in pname:
                net_parameters_id['weights'].append(p)
            elif 'bias' in pname:
                net_parameters_id['biases'].append(p)
        param_groups = [
                {'params': net_parameters_id['weights'], 'weight_decay': args.weight_decay},
                {'params': net_parameters_id['biases'], 'weight_decay': 0.}
            ]

    optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
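    # per-group 'lr'/'weight_decay' entries override the defaults passed here;
    # the defaults only apply to groups that omit those keys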
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
    

    if args.resume:
        if isfile(args.resume): 
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'"
                  .format(args.resume))
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch']
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # log
    log = Logger(join(OUT_DIR, 'log.txt'))
    sys.stdout = log

    train_loss = []
    train_loss_detail = []
    for epoch in range(args.start_epoch, args.maxepoch):
        if epoch == 0:
            print("Performing initial testing...")
            test(model, test_loader, epoch=epoch, test_list=test_list,
                 save_dir = join(OUT_DIR, 'initial-testing-record'))

        tr_avg_loss, tr_detail_loss = train(
            train_loader, model, optimizer, epoch,
            save_dir = join(OUT_DIR, 'epoch-%d-training-record' % epoch))
        test(model, test_loader, epoch=epoch, test_list=test_list,
            save_dir = join(OUT_DIR, 'epoch-%d-testing-record' % epoch))
        log.flush() # write log
        # Save checkpoint
        save_file = os.path.join(OUT_DIR, 'checkpoint_epoch{}.pth'.format(epoch))
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
                         }, filename=save_file)
        scheduler.step() # will adjust learning rate
        # save train/val loss/accuracy, save every epoch in case of early stop
        train_loss.append(tr_avg_loss)
        train_loss_detail += tr_detail_loss
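
The two loss lists are accumulated "in case of early stop" but never written to disk in this excerpt. A minimal sketch of persisting them, meant to sit at the end of the epoch loop above (the file name is illustrative):

# inside the epoch loop, after train_loss_detail is extended:
torch.save({'train_loss': train_loss,
            'train_loss_detail': train_loss_detail},
           os.path.join(OUT_DIR, 'loss_history.pth'))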