def main(args):
    """Build loaders for the chosen dataset, restore the latest checkpoint,
    and run evaluation plus test inference.

    Args:
        args: parsed CLI namespace with at least ``dataset`` ("cityscapes" or
            "camvid"), ``batch_size``, ``cuda`` (first GPU index),
            ``only_encode`` and ``extra_data``.
    """
    # NOTE(review): dataset paths are hard-coded; consider exposing them on args.
    if args.dataset == "cityscapes":
        train_dataset = DatasetTrain(cityscapes_data_path="/home/chenxiaoshuang/Cityscapes",
                                     cityscapes_meta_path="/home/chenxiaoshuang/Cityscapes/gtFine",
                                     only_encode=args.only_encode, extra_data=args.extra_data)
        val_dataset = DatasetVal(cityscapes_data_path="/home/chenxiaoshuang/Cityscapes",
                                 cityscapes_meta_path="/home/chenxiaoshuang/Cityscapes/gtFine",
                                 only_encode=args.only_encode)
        test_dataset = DatasetTest(cityscapes_data_path="/home/chenxiaoshuang/Cityscapes",
                                   cityscapes_meta_path="/home/chenxiaoshuang/Cityscapes/gtFine")
        num_classes = 20
    elif args.dataset == "camvid":
        train_dataset = DatasetCamVid(camvid_data_path="/home/chenxiaoshuang/CamVid",
                                      camvid_meta_path="/home/chenxiaoshuang/CamVid",
                                      only_encode=args.only_encode, mode="train")
        val_dataset = DatasetCamVid(camvid_data_path="/home/chenxiaoshuang/CamVid",
                                    camvid_meta_path="/home/chenxiaoshuang/CamVid",
                                    only_encode=args.only_encode, mode="val")
        test_dataset = DatasetCamVid(camvid_data_path="/home/chenxiaoshuang/CamVid",
                                     camvid_meta_path="/home/chenxiaoshuang/CamVid",
                                     only_encode=args.only_encode, mode="test")
        num_classes = 12
    else:
        print("Unsupported Dataset!")
        return

    # Loader construction was duplicated in both dataset branches; it is
    # identical for either dataset, so build the loaders once here.
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=8, drop_last=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=8)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=8)

    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available() else "cpu")
    # NOTE(review): assumes GPU args.cuda+1 also exists (two-GPU DataParallel) -- confirm.
    device_ids = [args.cuda, args.cuda + 1]
    cfg = Config(args.dataset, args.only_encode, args.extra_data)
    net = Net(num_classes=num_classes)

    # Bug fix: `weight` was previously assigned only when CUDA was available,
    # so the criterion constructors raised NameError on CPU-only machines.
    # `.to(device)` is a cheap no-op when `device` is the CPU.
    weight = cfg.weight.to(device)
    criterion1 = CrossEntropyLoss2d(weight)
    criterion2 = LovaszSoftmax(weight=weight)

    optimizer = optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999),
                           eps=1e-08, weight_decay=1e-4)

    # Polynomial learning-rate decay over 300 epochs with power 0.9.
    exp_lr_scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / 300) ** 0.9)

    trainer = Trainer('training', optimizer, exp_lr_scheduler, net, cfg,
                      './log', device, device_ids, num_classes)
    # Resume from the most recent checkpoint before evaluating.
    trainer.load_weights(trainer.find_last(), encode=False, restart=False)
    # trainer.train(train_loader, val_loader, criterion1, criterion2, 300)
    trainer.evaluate(val_loader)
    trainer.test(test_loader)

    print('Finished Training')
# ===== Example 2 =====
from loss import CrossEntropyLoss2d


if __name__ == '__main__':

    cfg = Config()

    # Datasets and loaders (paths are hard-coded; adjust for your machine).
    train_dataset = DatasetTrain(cityscapes_data_path="/home/shen/Data/DataSet/Cityscape",
                                 cityscapes_meta_path="/home/shen/Data/DataSet/Cityscape/gtFine/")
    val_dataset = DatasetVal(cityscapes_data_path="/home/shen/Data/DataSet/Cityscape",
                             cityscapes_meta_path="/home/shen/Data/DataSet/Cityscape/gtFine")
    train_loader = DataLoader(dataset=train_dataset, batch_size=10,
                              shuffle=True, num_workers=8)
    val_loader = DataLoader(dataset=val_dataset, batch_size=12,
                            shuffle=False, num_workers=8)

    net = xceptionAx3(num_classes=20)

    # Bug fix: this file imports CrossEntropyLoss2d, but the original code
    # instantiated the undefined name CrossEntropyLoss, which raises
    # NameError here. Use the imported 2D segmentation loss instead.
    criterion = CrossEntropyLoss2d()

    # Select the optimizer.
    optimizer = optim.SGD(net.parameters(), lr=0.5, momentum=0.9,
                          weight_decay=0.00001)

    # Polynomial learning-rate decay over 160k iterations with power 0.9.
    lr_fc = lambda iteration: (1 - iteration / 160000) ** 0.9
    exp_lr_scheduler = lr_scheduler.LambdaLR(optimizer, lr_fc, -1)

    trainer = Trainer('training', optimizer, exp_lr_scheduler, net, cfg, './log')
    # Resume from the most recent checkpoint, if one exists.
    trainer.load_weights(trainer.find_last())
    trainer.train(train_loader, val_loader, criterion, 640)
    # trainer.evaluate(valid_loader)
    print('Finished Training')
# ===== Example 3 =====
# --- Model, loss, and optimisation ---------------------------------------
cfig = Config()  # configuration object for the model
net = se_resnet_18()  # the CNN to be trained
criterion = nn.BCELoss()  # binary cross-entropy loss
optimizer = optim.SGD(net.parameters(), lr=0.001,
                      momentum=0.9)  # select the optimizer
# StepLR divides the learning rate by `gamma` once every `step_size` epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

# --- Data -----------------------------------------------------------------
# Train and validation splits of the cloud dataset, converted to tensors.
train_tarnsformed_dataset = CloudDataset(
    img_dir=img_dir,
    labels_dir=labels_dir,
    transform=transforms.Compose([ToTensor()]),
)
val_tarnsformed_dataset = CloudDataset(
    img_dir=img_dir,
    labels_dir=labels_dir,
    val=True,
    transform=transforms.Compose([ToTensor()]),
)
train_dataloader = DataLoader(train_tarnsformed_dataset, batch_size=8,
                              shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_tarnsformed_dataset, batch_size=8,
                            shuffle=True, num_workers=4)

# --- Training -------------------------------------------------------------
trainer = Trainer('training', optimizer, exp_lr_scheduler, net, cfig, './log')
# Load the latest checkpoint and continue training from it.
trainer.load_weights(trainer.find_last())
trainer.train(train_dataloader, val_dataloader, criterion, epochs)
print('Finished Training')