# Example #1
# 0
def main():
    """Train an FCN or DeepLabV3 (ResNet-101) model on VOC2012 + SBD.

    Steps: parse and validate the command-line options, create the save
    directory and write the hyper-parameters into it, build the model and the
    data loaders (train = VOC2012 train + SBD train_noval, optionally
    subsampled with ``--split``; val = VOC2012 val), train with SGD and a
    cross-entropy loss through ``ev.train_fully_supervised`` and, when
    ``--eval_angle`` is set, run ``ev.eval_model_all_angle`` on the train data
    and then on the val data.

    Raises:
        ValueError: if ``--size_img`` is smaller than ``--size_crop`` or if
            ``--model`` is neither "FCN" nor "DLV3" (case-insensitive). Both
            checks run before the save directory is created.
    """
    #torch.manual_seed(42)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    # NOTE(review): --auto_lr is parsed but not used by this function.
    parser.add_argument('--auto_lr',
                        type=U.str2bool,
                        default=False,
                        help="Auto lr finder")
    # 10e-4 == 1e-3
    parser.add_argument('--learning_rate', type=float, default=10e-4)
    parser.add_argument('--scheduler', type=U.str2bool, default=False)
    parser.add_argument('--wd', type=float, default=2e-4)
    parser.add_argument('--moment', type=float, default=0.9)
    parser.add_argument('--batch_size', default=5, type=int)
    parser.add_argument('--n_epochs', default=10, type=int)
    parser.add_argument('--model',
                        default='FCN',
                        type=str,
                        help="FCN or DLV3 model")
    parser.add_argument('--pretrained',
                        default=False,
                        type=U.str2bool,
                        help="Use pretrained pytorch model")
    parser.add_argument('--eval_angle',
                        default=True,
                        type=U.str2bool,
                        help="If true, it'll eval the model with different angle input size")
    parser.add_argument('--rotate',
                        default=False,
                        type=U.str2bool,
                        help="Use random rotation as data augmentation")
    parser.add_argument('--scale',
                        default=True,
                        type=U.str2bool,
                        help="Use scale as data augmentation")
    parser.add_argument('--size_img',
                        default=520,
                        type=int,
                        help="Size of input images")
    parser.add_argument('--size_crop',
                        default=480,
                        type=int,
                        help="Size of crop image during training")
    parser.add_argument('--nw',
                        default=0,
                        type=int,
                        help="Num workers for the data loader")
    parser.add_argument('--pm',
                        default=True,
                        type=U.str2bool,
                        help="Pin memory for the dataloader")
    parser.add_argument('--gpu',
                        default=0,
                        type=int,
                        help="Wich gpu to select for training")
    parser.add_argument('--benchmark',
                        default=False,
                        type=U.str2bool,
                        help="enable or disable backends.cudnn")
    parser.add_argument('--split',
                        default=False,
                        type=U.str2bool,
                        help="Split the dataset")
    parser.add_argument('--split_ratio',
                        default=0.3,
                        type=float,
                        help="Amount of data we used for training")
    parser.add_argument('--dataroot_voc',
                        default='/share/DEEPLEARNING/datasets/voc2012/',
                        type=str)
    parser.add_argument('--dataroot_sbd',
                        default='/share/DEEPLEARNING/datasets/sbd/',
                        type=str)
    parser.add_argument('--model_name',
                        type=str,
                        help="what name to use for saving")
    parser.add_argument('--save_dir', default='/data/save_model', type=str)
    parser.add_argument('--save_all_ep',
                        default=False,
                        type=U.str2bool,
                        help="If true it'll save the model every epoch in save_dir")
    parser.add_argument('--save_best',
                        default=False,
                        type=U.str2bool,
                        help="If true will only save the best epoch model")
    args = parser.parse_args()

    # ------------
    # validate args
    # ------------
    # Check everything that can be checked from the arguments alone before
    # touching the disk. Otherwise a typo leaves an empty run directory
    # (with hparams) behind and may build the model for nothing.
    if args.size_img < args.size_crop:
        raise ValueError(
            'Cannot have size of input images less than size of crop')
    model_type = args.model.upper()
    if model_type not in ('FCN', 'DLV3'):
        raise ValueError('model must be "FCN" or "DLV3"')

    # ------------
    # save
    # ------------
    save_dir = U.create_save_directory(args.save_dir)
    print('model will be saved in', save_dir)
    U.save_hparams(args, save_dir)

    # ------------
    # device
    # ------------
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    print("device used:", device)

    # ------------
    # model
    # ------------
    if model_type == 'FCN':
        model = models.segmentation.fcn_resnet101(pretrained=args.pretrained)
    else:  # 'DLV3' -- the only other value accepted by the check above
        model = models.segmentation.deeplabv3_resnet101(
            pretrained=args.pretrained)
    model.to(device)

    # ------------
    # data
    # ------------
    size_img = (args.size_img, args.size_img)
    size_crop = (args.size_crop, args.size_crop)
    # Augmentation (rotate / scale / crop) is only configured for the
    # training sets; the VOC val set uses the dataset defaults.
    train_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc,
                                              year='2012',
                                              image_set='train',
                                              download=True,
                                              rotate=args.rotate,
                                              scale=args.scale,
                                              size_img=size_img,
                                              size_crop=size_crop)
    val_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc,
                                            year='2012',
                                            image_set='val',
                                            download=True)
    train_dataset_SBD = mdset.SBDataset(args.dataroot_sbd,
                                        image_set='train_noval',
                                        mode='segmentation',
                                        rotate=args.rotate,
                                        scale=args.scale,
                                        size_img=size_img,
                                        size_crop=size_crop)
    # Concatenate the VOC and SBD training sets.
    train_dataset = tud.ConcatDataset([train_dataset_VOC, train_dataset_SBD])
    if args.split:
        train_dataset = U.split_dataset(train_dataset, args.split_ratio)
    print("There is", len(train_dataset), "images for training and",
          len(val_dataset_VOC), "for validation")
    # drop_last=True keeps every training batch at exactly batch_size.
    dataloader_train = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   num_workers=args.nw,
                                                   pin_memory=args.pm,
                                                   shuffle=True,
                                                   drop_last=True)
    dataloader_val = torch.utils.data.DataLoader(val_dataset_VOC,
                                                 num_workers=args.nw,
                                                 pin_memory=args.pm,
                                                 batch_size=args.batch_size)

    # ------------
    # training
    # ------------
    num_classes = 21
    # Label index 21 (== num_classes) is the border class, which is ignored
    # by the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=num_classes)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.moment,
                                weight_decay=args.wd)
    ev.train_fully_supervised(model=model,
                              n_epochs=args.n_epochs,
                              train_loader=dataloader_train,
                              val_loader=dataloader_val,
                              criterion=criterion,
                              optimizer=optimizer,
                              save_folder=save_dir,
                              scheduler=args.scheduler,
                              model_name=args.model_name,
                              benchmark=args.benchmark,
                              save_best=args.save_best,
                              save_all_ep=args.save_all_ep,
                              device=device,
                              num_classes=num_classes)

    # ------------
    # final evaluation
    # ------------
    if args.eval_angle:
        # Evaluate on the train data first, then on the val data.
        # NOTE(review): both results are written with the same
        # U.save_eval_angle(d_iou, save_dir) call -- confirm that the second
        # call does not overwrite the first.
        for on_train in (True, False):
            d_iou = ev.eval_model_all_angle(model,
                                            args.size_img,
                                            args.dataroot_voc,
                                            train=on_train,
                                            device=device)
            U.save_eval_angle(d_iou, save_dir)
def main():
    """Train an FCN or DeepLabV3 (ResNet-101) model fully supervised.

    Data is either VOC2012 train + SBD train_noval (plus COCO with
    ``--extra_coco``), validated on VOC2012 val with 21 classes, or the
    Landcover dataset (``--landcover``: trainval for training, test for
    validation, 4 classes). The model is trained with SGD and a
    cross-entropy loss through ``ev.train_fully_supervised``. The
    hyper-parameters are saved in a new directory created under
    ``--save_dir``.

    Raises:
        ValueError: if ``--size_img`` is smaller than ``--size_crop`` or if
            ``--model`` is neither "FCN" nor "DLV3" (case-insensitive).
    """
    #torch.manual_seed(42)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()

    # Learning parameters
    parser.add_argument('--auto_lr', type=U.str2bool, default=False, help="Auto lr finder")
    # 10e-4 == 1e-3
    parser.add_argument('--learning_rate', type=float, default=10e-4)
    parser.add_argument('--scheduler', type=U.str2bool, default=False)
    parser.add_argument('--wd', type=float, default=2e-4)
    parser.add_argument('--moment', type=float, default=0.9)
    parser.add_argument('--batch_size', default=5, type=int)
    parser.add_argument('--n_epochs', default=10, type=int)
    # NOTE(review): --iter_every, --eval_angle and --scale are parsed (and
    # saved with the hparams) but not used by this function.
    parser.add_argument('--iter_every', default=1, type=int, help="Accumulate compute graph for iter_size step")
    parser.add_argument('--benchmark', default=False, type=U.str2bool, help="enable or disable backends.cudnn")

    # Model and eval
    parser.add_argument('--model', default='FCN', type=str, help="FCN or DLV3 model")
    parser.add_argument('--pretrained', default=False, type=U.str2bool, help="Use pretrained pytorch model")
    parser.add_argument('--eval_angle', default=True, type=U.str2bool,
                        help="If true, it'll eval the model with different angle input size")

    # Data augmentation
    parser.add_argument('--rotate', default=False, type=U.str2bool, help="Use random rotation as data augmentation")
    parser.add_argument('--pi_rotate', default=True, type=U.str2bool, help="Use only pi/2 rotation angle")
    parser.add_argument('--p_rotate', default=0.25, type=float, help="Probability of rotating the image during the training")
    parser.add_argument('--scale', default=True, type=U.str2bool, help="Use scale as data augmentation")
    parser.add_argument('--landcover', default=False, type=U.str2bool,
                        help="Use Landcover dataset instead of VOC and COCO")
    parser.add_argument('--size_img', default=520, type=int, help="Size of input images")
    parser.add_argument('--size_crop', default=480, type=int, help="Size of crop image during training")
    parser.add_argument('--angle_max', default=360, type=int, help="Angle max for data augmentation")

    # Dataloader and gpu
    parser.add_argument('--nw', default=0, type=int, help="Num workers for the data loader")
    parser.add_argument('--pm', default=True, type=U.str2bool, help="Pin memory for the dataloader")
    parser.add_argument('--gpu', default=0, type=int, help="Wich gpu to select for training")

    # Datasets
    parser.add_argument('--split', default=False, type=U.str2bool, help="Split the dataset")
    parser.add_argument('--split_ratio', default=0.3, type=float, help="Amount of data we used for training")
    parser.add_argument('--dataroot_voc', default='/data/voc2012', type=str)
    parser.add_argument('--dataroot_sbd', default='/data/sbd', type=str)
    parser.add_argument('--dataroot_landcover', default='/share/DEEPLEARNING/datasets/landcover', type=str)
    # The VOC branch below reads args.extra_coco and args.dataroot_coco.
    # Without these two declarations that branch raises AttributeError.
    parser.add_argument('--extra_coco', default=False, type=U.str2bool,
                        help="Use coco dataset as extra annotation for fully supervised training")
    parser.add_argument('--dataroot_coco', default='/share/DEEPLEARNING/datasets/coco', type=str)

    # Save parameters
    parser.add_argument('--model_name', type=str, help="what name to use for saving")
    parser.add_argument('--save_dir', default='/data/save_model', type=str)
    parser.add_argument('--save_all_ep', default=False, type=U.str2bool,
                        help="If true it'll save the model every epoch in save_dir")
    parser.add_argument('--save_best', default=False, type=U.str2bool, help="If true will only save the best epoch model")
    args = parser.parse_args()

    # ------------
    # device
    # ------------
    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    print("device used:", device)

    # ------------
    # validate args
    # ------------
    # Both checks run before any dataset is loaded or downloaded.
    if args.size_img < args.size_crop:
        raise ValueError('Cannot have size of input images less than size of crop')
    model_type = args.model.upper()
    if model_type not in ('FCN', 'DLV3'):
        raise ValueError('model must be "FCN" or "DLV3"')

    # ------------
    # data
    # ------------
    size_img = (args.size_img, args.size_img)
    size_crop = (args.size_crop, args.size_crop)
    if not args.landcover:
        train_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc, year='2012', image_set='train',
                                                  download=True, rotate=args.rotate,
                                                  size_img=size_img, size_crop=size_crop)
        test_dataset = mdset.VOCSegmentation(args.dataroot_voc, year='2012', image_set='val', download=True)
        train_dataset_SBD = mdset.SBDataset(args.dataroot_sbd, image_set='train_noval', mode='segmentation',
                                            rotate=args.rotate, size_img=size_img, size_crop=size_crop)
        train_parts = [train_dataset_VOC, train_dataset_SBD]
        if args.extra_coco:
            # COCO used as extra fully supervised training data.
            extra_COCO = cu.get_coco(args.dataroot_coco, 'train', rotate=args.rotate,
                                     size_img=size_img, size_crop=size_crop)
            train_parts.append(extra_COCO)
        train_dataset = tud.ConcatDataset(train_parts)
        num_classes = 21
    else:
        print('Loading Landscape Dataset')
        train_dataset = mdset.LandscapeDataset(args.dataroot_landcover, image_set="trainval",
                                               rotate=args.rotate, pi_rotate=args.pi_rotate,
                                               p_rotate=args.p_rotate, size_img=size_img,
                                               size_crop=size_crop, angle_max=args.angle_max)
        test_dataset = mdset.LandscapeDataset(args.dataroot_landcover, image_set="test")
        print('Success load Landscape Dataset')
        num_classes = 4

    if args.split:
        train_dataset = U.split_dataset(train_dataset, args.split_ratio)
    print("There is", len(train_dataset), "images for training and", len(test_dataset), "for validation")
    # drop_last=True keeps every training batch at exactly batch_size.
    dataloader_train = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.nw,
                                                   pin_memory=args.pm, shuffle=True, drop_last=True)
    dataloader_val = torch.utils.data.DataLoader(test_dataset, num_workers=args.nw, pin_memory=args.pm,
                                                 batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    if model_type == 'FCN':
        model = models.segmentation.fcn_resnet101(pretrained=args.pretrained, num_classes=num_classes)
    else:  # 'DLV3' -- the only other value accepted by the check above
        model = models.segmentation.deeplabv3_resnet101(pretrained=args.pretrained, num_classes=num_classes)
    # NOTE(review): the model is not moved to `device` here; presumably
    # ev.train_fully_supervised does it (it receives `device`) -- confirm.

    # ------------
    # save
    # ------------
    save_dir = U.create_save_directory(args.save_dir)
    print('model will be saved in', save_dir)
    U.save_hparams(args, save_dir)

    # ------------
    # training
    # ------------
    print(args)

    # Label index `num_classes` is the border class, which is ignored by the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=num_classes)
    # Anomaly detection is a debugging aid. It slows down every backward pass,
    # so consider turning it off for real training runs.
    torch.autograd.set_detect_anomaly(True)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.moment, weight_decay=args.wd)

    ev.train_fully_supervised(model=model, n_epochs=args.n_epochs, train_loader=dataloader_train,
                              val_loader=dataloader_val, criterion=criterion, optimizer=optimizer,
                              save_folder=save_dir, scheduler=args.scheduler, auto_lr=args.auto_lr,
                              model_name=args.model_name, benchmark=args.benchmark, save_best=args.save_best,
                              save_all_ep=args.save_all_ep, device=device, num_classes=num_classes)
# Example #3
# 0
# Alternative preprocessing settings (currently disabled):
#   scale_factor = (0.2, 0.8)
#   size_img = (420, 420)
#   size_crop = (380, 380)

# ------------
# device
# ------------
# Run on CUDA device index 1 when a GPU is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(1))
else:
    device = torch.device("cpu")
print("device :", device)

# ------------
# model
# ------------
# Load the saved model from the experiment directory, mapping its tensors
# onto `device`. torch.load unpickles the file: only load trusted
# checkpoints. `model_name` must be defined before this point.
model_dir = '/share/homes/karmimy/equiv/save_model/rot_equiv/49'
model = torch.load(join(model_dir, model_name), map_location=device)
if VOC:
    num_classes = 21
    train_dataset_VOC = mdset.VOCSegmentation(dataroot_voc,year='2012', image_set='train', \
            download=True,rotate=rotate,scale=scale,size_img=size_img,size_crop=size_crop)
    test_dataset = mdset.VOCSegmentation(dataroot_voc,
                                         year='2012',
                                         image_set='val',
                                         download=True)
    train_dataset_SBD = mdset.SBDataset(dataroot_sbd, image_set='train_noval',mode='segmentation',\
            rotate=rotate,scale=scale,size_img=size_img,size_crop=size_crop)
    train_dataset = tud.ConcatDataset([train_dataset_VOC, train_dataset_SBD])

else:
    num_classes = 4
    print('Loading Landscape Dataset')
    train_dataset = mdset.LandscapeDataset(dataroot_landcover,image_set="trainval",\
        rotate=rotate)#,size_img=size_img,size_crop=size_crop)
    test_dataset = mdset.LandscapeDataset(dataroot_landcover, image_set="test")
    test_dataset_no_norm = mdset.LandscapeDataset(dataroot_landcover,
# Example #4
# 0
def main():
    """Train a segmentation model with a supervised loss plus a rotation-equivariance loss.

    Builds two training sets:

    * a supervised set: VOC2012 train + SBD train_noval, subsampled with
      ``--split_ratio`` when ``--split`` is true;
    * an unsupervised set used for the equivariance loss: VOC + SBD by
      default, the supervised subset with ``--multi_task``, or VOC + SBD +
      COCO with ``--extra_coco`` (which takes precedence over
      ``--multi_task``).

    The model is either resumed with ``fbm.load_best_model``
    (``--load_last_model``) or built fresh as FCN / DeepLabV3 (ResNet-101).
    It is placed on the selected device in both cases and trained through
    ``ev.train_rot_equiv``, with VOC2012 val used for validation.

    Raises:
        ValueError: if ``--size_img`` is smaller than ``--size_crop``, or if
            a fresh model is requested and ``--model`` is neither "FCN" nor
            "DLV3" (case-insensitive).
    """
    #torch.manual_seed(42)
    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    # NOTE(review): --auto_lr is parsed but not used by this function.
    parser.add_argument('--auto_lr',
                        type=U.str2bool,
                        default=False,
                        help="Auto lr finder")
    # 10e-4 == 1e-3
    parser.add_argument('--learning_rate', type=float, default=10e-4)
    parser.add_argument('--Loss', type=str, default='KL')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help="gamma balance the two losses")
    parser.add_argument('--scheduler', type=U.str2bool, default=True)
    parser.add_argument('--wd', type=float, default=2e-4)
    parser.add_argument('--moment', type=float, default=0.9)
    parser.add_argument('--batch_size', default=5, type=int)
    parser.add_argument('--iter_every',
                        default=1,
                        type=int,
                        help="Accumulate compute graph for iter_size step")
    parser.add_argument('--n_epochs', default=10, type=int)
    parser.add_argument('--model',
                        default='DLV3',
                        type=str,
                        help="FCN or DLV3 model")
    parser.add_argument('--pretrained',
                        default=False,
                        type=U.str2bool,
                        help="Use pretrained pytorch model")
    parser.add_argument('--eval_angle',
                        default=True,
                        type=U.str2bool,
                        help="If true, it'll eval the model with different angle input size")
    parser.add_argument('--eval_every',
                        default=30,
                        type=int,
                        help="Eval all input rotation angle every n step")
    parser.add_argument('--rotate',
                        default=False,
                        type=U.str2bool,
                        help="Use random rotation as data augmentation")
    parser.add_argument('--angle_max',
                        default=30,
                        type=int,
                        help="Max angle rotation of input image")
    parser.add_argument('--size_img',
                        default=520,
                        type=int,
                        help="Size of input images")
    parser.add_argument('--size_crop',
                        default=480,
                        type=int,
                        help="Size of crop image during training")
    parser.add_argument('--nw',
                        default=0,
                        type=int,
                        help="Num workers for the data loader")
    parser.add_argument('--pm',
                        default=True,
                        type=U.str2bool,
                        help="Pin memory for the dataloader")
    parser.add_argument('--gpu',
                        default=0,
                        type=int,
                        help="Wich gpu to select for training")
    parser.add_argument('--rot_cpu',
                        default=False,
                        type=U.str2bool,
                        help="Apply rotation on the cpu (Help to use less gpu memory)")
    parser.add_argument('--benchmark',
                        default=False,
                        type=U.str2bool,
                        help="enable or disable backends.cudnn")
    parser.add_argument('--split',
                        default=True,
                        type=U.str2bool,
                        help="Split the dataset")
    parser.add_argument('--split_ratio',
                        default=0.3,
                        type=float,
                        help="Amount of data we used for training")
    parser.add_argument('--extra_coco',
                        default=False,
                        type=U.str2bool,
                        help="Use coco dataset as extra annotation for fully supervised training")
    parser.add_argument('--multi_task',
                        default=False,
                        type=U.str2bool,
                        help="Multi task training (same data for equiv and sup)")
    parser.add_argument('--dataroot_voc',
                        default='/share/DEEPLEARNING/datasets/voc2012',
                        type=str)
    parser.add_argument('--dataroot_sbd',
                        default='/share/DEEPLEARNING/datasets/sbd',
                        type=str)
    parser.add_argument('--dataroot_coco',
                        default='/share/DEEPLEARNING/datasets/coco',
                        type=str)
    parser.add_argument('--model_name',
                        type=str,
                        help="what name to use for saving")
    parser.add_argument('--save_dir', default='/data/save_model', type=str)
    parser.add_argument('--save_all_ep',
                        default=False,
                        type=U.str2bool,
                        help="If true it'll save the model every epoch in save_dir")
    parser.add_argument('--save_best',
                        default=False,
                        type=U.str2bool,
                        help="If true will only save the best epoch model")
    # The original help text embedded a long run of spaces through a backslash
    # continuation. argparse collapses whitespace, so the displayed help is
    # unchanged.
    parser.add_argument('--load_last_model',
                        default=False,
                        type=U.str2bool,
                        help="If it will load the last model saved with This parameters.")
    args = parser.parse_args()

    # ------------
    # device
    # ------------
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    print("device used:", device)

    # ------------
    # validate args
    # ------------
    if args.size_img < args.size_crop:
        raise ValueError(
            'Cannot have size of input images less than size of crop')
    # The model type only matters when a fresh model is built. Check it now,
    # before datasets are loaded and a new save directory is created.
    if not args.load_last_model and args.model.upper() not in ('FCN', 'DLV3'):
        raise ValueError('model must be "FCN" or "DLV3"')

    # ------------
    # data
    # ------------
    size_img = (args.size_img, args.size_img)
    size_crop = (args.size_crop, args.size_crop)
    train_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc,
                                              year='2012',
                                              image_set='train',
                                              download=True,
                                              rotate=args.rotate,
                                              size_img=size_img,
                                              size_crop=size_crop)
    val_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc,
                                            year='2012',
                                            image_set='val',
                                            download=True)
    train_dataset_SBD = mdset.SBDataset(args.dataroot_sbd,
                                        image_set='train_noval',
                                        mode='segmentation',
                                        rotate=args.rotate,
                                        size_img=size_img,
                                        size_crop=size_crop)
    # COCO dataset (only used for the unsupervised set, see below).
    if args.extra_coco:
        extra_COCO = cu.get_coco(args.dataroot_coco,
                                 'train',
                                 rotate=args.rotate,
                                 size_img=size_img,
                                 size_crop=size_crop)
    # Concatenate VOC and SBD.
    train_dataset_unsup = tud.ConcatDataset(
        [train_dataset_VOC, train_dataset_SBD])

    # The supervised set is optionally a subset of VOC + SBD.
    if args.split:
        train_dataset_sup = U.split_dataset(train_dataset_unsup,
                                            args.split_ratio)
    else:
        train_dataset_sup = train_dataset_unsup
    # Multi task: the same data is used for the supervised and the
    # equivariance losses.
    if args.multi_task:
        train_dataset_unsup = train_dataset_sup

    # With extra COCO, the unsupervised set is VOC + SBD + COCO.
    # NOTE(review): this overrides the --multi_task choice made just above.
    if args.extra_coco:
        train_dataset_unsup = tud.ConcatDataset(
            [train_dataset_VOC, train_dataset_SBD, extra_COCO])

    print("There is", len(train_dataset_sup), "images for supervised training",
          len(train_dataset_unsup), "for equivariance loss and",
          len(val_dataset_VOC), "for validation")

    # drop_last=True keeps every supervised batch at exactly batch_size.
    dataloader_train_sup = torch.utils.data.DataLoader(train_dataset_sup,
                                                       batch_size=args.batch_size,
                                                       num_workers=args.nw,
                                                       pin_memory=args.pm,
                                                       shuffle=True,
                                                       drop_last=True)
    dataloader_val = torch.utils.data.DataLoader(val_dataset_VOC,
                                                 num_workers=args.nw,
                                                 pin_memory=args.pm,
                                                 batch_size=args.batch_size)

    # ---------
    # Load model
    # ---------
    if args.load_last_model:
        model, save_dir = fbm.load_best_model(save_dir=args.save_dir,
                                              model_name=args.model_name,
                                              split=args.split,
                                              split_ratio=args.split_ratio,
                                              batch_size=args.batch_size,
                                              rotate=args.rotate)
        print("Training will continue from this file.", save_dir)
    else:
        save_dir = U.create_save_directory(
            args.save_dir)  # Create a new save directory
        if args.model.upper() == 'FCN':
            model = models.segmentation.fcn_resnet101(
                pretrained=args.pretrained)
        else:  # 'DLV3' -- the only other value accepted by the check above
            print('DEEPLAB MODEL')
            model = models.segmentation.deeplabv3_resnet101(
                pretrained=args.pretrained)
    # Move the model to the training device in both branches. Previously, a
    # reloaded model stayed wherever load_best_model had placed it.
    model.to(device)

    # ------------
    # save
    # ------------
    print('model will be saved in', save_dir)
    U.save_hparams(args, save_dir)

    # ------------
    # training
    # ------------
    # Label index 21 is the border class, which is ignored by the loss.
    criterion_supervised = nn.CrossEntropyLoss(ignore_index=21)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.moment,
                                weight_decay=args.wd)
    ev.train_rot_equiv(model, args.n_epochs, dataloader_train_sup, train_dataset_unsup,
                       dataloader_val, criterion_supervised, optimizer,
                       scheduler=args.scheduler, Loss=args.Loss, gamma=args.gamma,
                       batch_size=args.batch_size, iter_every=args.iter_every,
                       save_folder=save_dir, model_name=args.model_name,
                       benchmark=args.benchmark, angle_max=args.angle_max,
                       size_img=args.size_img, eval_every=args.eval_every,
                       save_all_ep=args.save_all_ep, dataroot_voc=args.dataroot_voc,
                       save_best=args.save_best, rot_cpu=args.rot_cpu, device=device)

    # Final evaluation
    """
# Example #5
# 0
def main():
    """Evaluate a saved segmentation model on VOC2012 val at fixed input rotations.

    Loads ``<--model_dir>/<--model_name>`` with ``torch.load``. For every
    angle in ``--angles``, builds the VOC2012 val set with
    ``fixing_rotate=True`` / ``angle_fix=angle``, runs ``eval_model`` (21
    classes) and prints the mean IoU, accuracy, CE loss and per-class IoU.
    At the end it prints the list of mean IoUs rounded to 3 decimals, with
    the first value repeated as the last element (presumably to close the
    curve when plotting over angles -- confirm).
    """
    #torch.manual_seed(42)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()

    # Model and eval
    # Other checkpoint names: rot_equiv_lc.pt, fcn_fully_sup_lc.pt
    parser.add_argument('--model_name',
                        default='equiv_dlv3.pt',
                        type=str,
                        help="Model name")
    parser.add_argument(
        '--model_dir',
        default='/share/homes/karmimy/equiv/save_model/rot_equiv/72',
        type=str,
        help="Directory containing the saved model")
    parser.add_argument('--expe',
                        default='72',
                        type=str,
                        help="Experiment id (currently unused)")
    # Default reproduces the previously hard-coded list.
    parser.add_argument('--angles',
                        default=[30, 0],
                        type=int,
                        nargs='+',
                        help="Fixed input rotation angles to evaluate")
    args = parser.parse_args()

    # DATASETS
    dataroot_voc = '/share/DEEPLEARNING/datasets/voc2012'

    # DataLoader settings
    nw = 4      # number of workers
    pm = True   # pin_memory
    # GPU
    gpu = 0
    # EVAL PARAMETERS
    bs = 1      # batch size

    # DEVICE
    device = torch.device("cuda:" +
                          str(gpu) if torch.cuda.is_available() else "cpu")
    print("device :", device)

    # torch.load unpickles the file: only load trusted checkpoints.
    model = torch.load(join(args.model_dir, args.model_name),
                       map_location=device)

    l_iou = []
    for angle in args.angles:
        # A new dataset is built for each angle because the rotation is fixed
        # at construction time (angle_fix).
        test_dataset = mdset.VOCSegmentation(dataroot_voc,
                                             year='2012',
                                             image_set='val',
                                             download=False,
                                             fixing_rotate=True,
                                             angle_fix=angle)
        dataloader_val = torch.utils.data.DataLoader(test_dataset,
                                                     num_workers=nw,
                                                     pin_memory=pm,
                                                     batch_size=bs)
        state = eval_model(model,
                           dataloader_val,
                           device=device,
                           num_classes=21)
        m_iou = state.metrics['mean IoU']
        iou = state.metrics['IoU']
        acc = state.metrics['accuracy']
        loss = state.metrics['CE Loss']
        l_iou.append(round(m_iou, 3))
        print('EVAL FOR ANGLE', angle, ': IoU', m_iou, 'ACC:', acc, 'LOSS',
              loss)
        print('IoU All classes', iou)
    # Repeat the first value at the end (guarded against an empty list).
    if l_iou:
        l_iou.append(l_iou[0])

    print('L_IOU', l_iou)
def main():
    """Train a segmentation model with a supervised loss plus a scale-equivariance loss.

    Parses command-line arguments, builds an FCN or DeepLabV3 (ResNet-101)
    model from ``torchvision``, pools the VOC 2012 ``train`` and SBD
    ``train_noval`` sets, optionally splits off a supervised subset, and hands
    everything to ``ev.train_scale_equiv``. Hyper-parameters are written to the
    save directory via ``U.save_hparams`` before training starts.

    Raises:
        ValueError: if ``--model`` is not FCN/DLV3, if ``--size_img`` is
            smaller than ``--size_crop``, or if ``--scale_factor`` does not
            receive exactly two values (min max).
    """
    # torch.manual_seed(42)
    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--auto_lr',
                        type=U.str2bool,
                        default=False,
                        help="Auto lr finder")
    parser.add_argument('--learning_rate', type=float, default=10e-4)
    parser.add_argument('--Loss', type=str, default='KL')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.5,
                        help="gamma balance the two losses")
    parser.add_argument('--scheduler', type=U.str2bool, default=False)
    parser.add_argument('--wd', type=float, default=2e-4)
    parser.add_argument('--moment', type=float, default=0.9)
    parser.add_argument('--batch_size', default=5, type=int)
    parser.add_argument('--n_epochs', default=10, type=int)
    parser.add_argument('--model',
                        default='FCN',
                        type=str,
                        help="FCN or DLV3 model")
    parser.add_argument('--pretrained',
                        default=False,
                        type=U.str2bool,
                        help="Use pretrained pytorch model")
    parser.add_argument('--eval_every',
                        default=30,
                        type=int,
                        help="Eval all input rotation angle every n step")
    parser.add_argument('--rotate',
                        default=False,
                        type=U.str2bool,
                        help="Use random rotation as data augmentation")
    parser.add_argument('--scale',
                        default=True,
                        type=U.str2bool,
                        help="Use scale as data augmentation")
    # Expects two values on the command line: `--scale_factor MIN MAX`.
    # The default matches the range this script always used.
    parser.add_argument('--scale_factor',
                        default=(0.2, 0.8),
                        type=float,
                        nargs='+',
                        help="Scale image between min*size - max*size")
    parser.add_argument('--size_img',
                        default=520,
                        type=int,
                        help="Size of input images")
    parser.add_argument('--size_crop',
                        default=480,
                        type=int,
                        help="Size of crop image during training")
    parser.add_argument('--nw',
                        default=0,
                        type=int,
                        help="Num workers for the data loader")
    parser.add_argument('--pm',
                        default=True,
                        type=U.str2bool,
                        help="Pin memory for the dataloader")
    parser.add_argument('--gpu',
                        default=0,
                        type=int,
                        help="Wich gpu to select for training")
    parser.add_argument('--benchmark',
                        default=False,
                        type=U.str2bool,
                        help="enable or disable backends.cudnn")
    parser.add_argument('--split',
                        default=True,
                        type=U.str2bool,
                        help="Split the dataset")
    parser.add_argument('--split_ratio',
                        default=0.3,
                        type=float,
                        help="Amount of data we used for training")
    parser.add_argument(
        '--multi_task',
        default=False,
        type=U.str2bool,
        help="Multi task training (same data for equiv and sup)")
    parser.add_argument('--dataroot_voc', default='/data/voc2012', type=str)
    parser.add_argument('--dataroot_sbd', default='/data/sbd', type=str)
    parser.add_argument('--model_name',
                        type=str,
                        help="what name to use for saving")
    parser.add_argument('--save_dir', default='/data/save_model', type=str)
    parser.add_argument('--save_all_ep',
                        default=False,
                        type=U.str2bool,
                        help="If true it'll save the model every epoch in save_dir")
    parser.add_argument('--save_best',
                        default=False,
                        type=U.str2bool,
                        help="If true will only save the best epoch model")
    args = parser.parse_args()
    # NOTE(review): --auto_lr and --eval_every are parsed (and passed to
    # U.save_hparams) but are not used anywhere in this function.

    # Validate the configuration before creating directories or downloading
    # weights/datasets, so that a bad command line fails fast.
    model_key = args.model.upper()
    if model_key not in ('FCN', 'DLV3'):
        raise ValueError('model must be "FCN" or "DLV3"')
    if args.size_img < args.size_crop:
        raise ValueError(
            'Cannot have size of input images less than size of crop')
    if len(args.scale_factor) != 2:
        raise ValueError('--scale_factor expects exactly two values: min max')
    scale_factor = tuple(args.scale_factor)

    # ------------
    # save
    # ------------
    save_dir = U.create_save_directory(args.save_dir)
    print('model will be saved in', save_dir)
    U.save_hparams(args, save_dir)
    # ------------
    # device
    # ------------
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    print("device used:", device)
    # ------------
    # model
    # ------------
    if model_key == 'FCN':
        model = models.segmentation.fcn_resnet101(pretrained=args.pretrained)
    else:  # 'DLV3' (validated above)
        model = models.segmentation.deeplabv3_resnet101(
            pretrained=args.pretrained)
    model.to(device)
    # ------------
    # data
    # ------------
    size_img = (args.size_img, args.size_img)
    size_crop = (args.size_crop, args.size_crop)
    train_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc,
                                              year='2012',
                                              image_set='train',
                                              download=True,
                                              rotate=args.rotate,
                                              scale=args.scale,
                                              size_img=size_img,
                                              size_crop=size_crop)
    val_dataset_VOC = mdset.VOCSegmentation(args.dataroot_voc,
                                            year='2012',
                                            image_set='val',
                                            download=True)
    train_dataset_SBD = mdset.SBDataset(args.dataroot_sbd,
                                        image_set='train_noval',
                                        mode='segmentation',
                                        rotate=args.rotate,
                                        scale=args.scale,
                                        size_img=size_img,
                                        size_crop=size_crop)
    # Concatenate VOC and SBD; this pool is what the equivariance loss sees.
    train_dataset_unsup = tud.ConcatDataset(
        [train_dataset_VOC, train_dataset_SBD])

    # Supervised subset: presumably U.split_dataset keeps the `split_ratio`
    # fraction of the pool (see --split_ratio help) -- TODO confirm.
    if args.split:
        train_dataset_sup = U.split_dataset(train_dataset_unsup,
                                            args.split_ratio)
    else:
        # Without a split, the whole pool is also used for supervised training.
        train_dataset_sup = train_dataset_unsup
    # Multi-task: the equivariance loss uses the same images as the
    # supervised loss.
    if args.multi_task:
        train_dataset_unsup = train_dataset_sup
    print("There is", len(train_dataset_sup), "images for supervised training",
          len(train_dataset_unsup), "for equivariance loss and",
          len(val_dataset_VOC), "for validation")

    dataloader_train_sup = torch.utils.data.DataLoader(train_dataset_sup,
                                                       batch_size=args.batch_size,
                                                       num_workers=args.nw,
                                                       pin_memory=args.pm,
                                                       shuffle=True,
                                                       drop_last=True)
    dataloader_val = torch.utils.data.DataLoader(val_dataset_VOC,
                                                 num_workers=args.nw,
                                                 pin_memory=args.pm,
                                                 batch_size=args.batch_size)

    # ------------
    # training
    # ------------
    # Label index 21 is ignored by the loss (the "border" class).
    criterion_supervised = nn.CrossEntropyLoss(ignore_index=21)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.moment,
                                weight_decay=args.wd)
    ev.train_scale_equiv(model,
                         args.n_epochs,
                         dataloader_train_sup,
                         train_dataset_unsup,
                         dataloader_val,
                         criterion_supervised,
                         optimizer,
                         scheduler=args.scheduler,
                         Loss=args.Loss,
                         gamma=args.gamma,
                         batch_size=args.batch_size,
                         save_folder=save_dir,
                         model_name=args.model_name,
                         benchmark=args.benchmark,
                         scale_factor=scale_factor,
                         size_img=args.size_img,
                         save_all_ep=args.save_all_ep,
                         dataroot_voc=args.dataroot_voc,
                         save_best=args.save_best,
                         device=device)