Example #1
    def __init__(self, argv):
        self.args = argv
        self.nclass = 5
        self.gpu_ids = [0]
        self.cuda = True
        self.crop_size = 321

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone='resnet',
                             output_stride=16,
                             sync_bn=False,
                             freeze_bn=False)

        # Load weights
        #modelDir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        #modelPath = os.path.join(modelDir, 'finalModels/MM_PM_Sag_model.pth.tar')
        modelPath = '/software/models/MM_PM_Sag_model.pth.tar'
        checkpoint = torch.load(modelPath)

        if self.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        print('Loaded model weights')
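
Editor's note: most examples on this page repeat the same load pattern, so here is a minimal, self-contained sketch of it using only torch. The tiny Conv2d stands in for DeepLab, and `patch_replication_callback` (from the SynchronizedBatchNorm implementation these repositories bundle) is shown commented out because it is repository-specific.

import torch
import torch.nn as nn

model = nn.Conv2d(3, 5, kernel_size=1)           # stand-in for DeepLab
checkpoint = {'state_dict': model.state_dict()}  # stand-in for torch.load(path)

if torch.cuda.is_available():
    model = nn.DataParallel(model, device_ids=[0])
    # patch_replication_callback(model)  # only needed with SynchronizedBatchNorm
    model = model.cuda()
    # DataParallel wraps the network, so the weights live under .module
    model.module.load_state_dict(checkpoint['state_dict'])
else:
    model.load_state_dict(checkpoint['state_dict'])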
Example #2
    def __init__(self, model_path):
        self.lr = 0.007
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        if DFPE_test.cuda and len(DFPE_test.gpu_ids) > 1:
            self.sync_bn = True
        else:
            self.sync_bn = False
        self.model = DFPENet(num_classes=6, backbone='resnet', output_stride=16,
                             sync_bn=self.sync_bn, freeze_bn=False)
        self.train_params = [{'params': self.model.get_1x_lr_params(), 'lr': self.lr},
                             {'params': self.model.get_10x_lr_params(), 'lr': self.lr * 10}]
        self.optimizer = torch.optim.SGD(self.train_params, momentum=0.9,
                                         weight_decay=5e-4, nesterov=False)
        if DFPE_test.cuda:
            self.model = self.model.cuda()
            self.model = torch.nn.DataParallel(self.model, device_ids=DFPE_test.gpu_ids)
            patch_replication_callback(self.model)
        # Load the checkpoint, mapping to CPU when CUDA is unavailable
        if DFPE_test.cuda:
            checkpoint = torch.load(model_path)
        else:
            device = torch.device("cpu")
            checkpoint = torch.load(model_path, map_location=device)
        epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
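
The CPU branch above relies on map_location, which is the portable way to load a GPU-trained checkpoint on a machine without CUDA. A runnable sketch, using an in-memory buffer in place of a checkpoint file:

import io
import torch

buf = io.BytesIO()
torch.save({'epoch': 3, 'state_dict': {}}, buf)
buf.seek(0)
# map_location remaps any CUDA tensors stored in the checkpoint onto the CPU
checkpoint = torch.load(buf, map_location=torch.device('cpu'))
print(checkpoint['epoch'])  # 3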
Example #3
    def __init__(self, argv):
        self.args = argv
        self.nclass = 10
        self.gpu_ids = [0]
        self.dataset = 'heart'
        #whether to use cuda (can be an additional argument i.e. sys.argv[3][5:])
        self.cuda = True
        self.crop_size = 513

        #load model and weights here

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone='resnet',
                             output_stride=16,
                             sync_bn=False,
                             freeze_bn=False)

        # requires nclass = 10
        model3 = torch.load('/software/heartModels/heart_peri_model.pth.tar')
        if self.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.model.module.load_state_dict(model3['state_dict'])
        else:
            self.model.load_state_dict(model3['state_dict'])
        print('Finished loading weights')
Example #4
    def __init__(self, args):
        self.args = args

        # configure datasetpath
        self.baseroot = None
        if args.dataset == 'pascal':
            self.baseroot = '/path/to/your/VOCdevkit/VOC2012/'
        ''' not supported:
        # if you want to train on these datasets,
        # you need to modify this section;
        # refer to /dataloader/datasets/pascal to
        # implement the corresponding dataset constructor

        elif args.dataset == 'cityscapes':
            self.baseroot = '/path/to/your/cityscapes/'
        elif args.dataset == 'sbd':
            self.baseroot = '/path/to/your/sbd/'
        elif args.dataset == 'coco':
            self.baseroot = '/path/to/your/coco/'
        '''

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.test_loader, self.nclass = make_data_loader(
            self.baseroot, args, **kwargs)

        #define net model
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=False,
                             freeze_bn=False).cuda()

        # self.model.module.load_state_dict(torch.load('./model_best.pth.tar', map_location='cpu'))
        self.evaluator = Evaluator(self.nclass)

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0

        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])

        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
Example #5
    def __init__(self, args, ori_img_lst, init_mask_lst):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.ori_img_lst = ori_img_lst
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.test_loader, self.nclass = make_data_loader_demo(
            args, args.test_folder, ori_img_lst, init_mask_lst, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn,
                        use_iou=args.use_maskiou)

        self.model = model

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #6
    def __init__(self, para):
        self.args = para

        # Define Saver
        self.saver = Saver(para)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataloader(
            para)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=para.backbone,
                        output_stride=para.out_stride,
                        sync_bn=para.sync_bn,
                        freeze_bn=para.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': para.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': para.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=para.momentum,
                                    weight_decay=para.weight_decay,
                                    nesterov=para.nesterov)

        # Define Criterion

        self.criterion = SegmentationLosses(
            weight=None, cuda=True).build_loss(mode=para.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(para.lr_scheduler, para.lr, para.epochs,
                                      len(self.train_loader))

        self.model = torch.nn.DataParallel(self.model)
        patch_replication_callback(self.model)
        self.model = self.model.cuda()
        # Resuming checkpoint
        self.best_pred = 0.0
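
The two parameter groups above implement the standard DeepLab schedule: the pretrained backbone trains at the base learning rate while the randomly initialized head trains ten times faster. A minimal sketch with stand-in modules (get_1x_lr_params / get_10x_lr_params are repository methods that partition the real model the same way):

import torch
import torch.nn as nn

backbone, head = nn.Linear(8, 8), nn.Linear(8, 2)  # stand-ins for the real parts
base_lr = 0.007
optimizer = torch.optim.SGD(
    [{'params': backbone.parameters(), 'lr': base_lr},
     {'params': head.parameters(), 'lr': base_lr * 10}],
    momentum=0.9, weight_decay=5e-4)
print([g['lr'] for g in optimizer.param_groups])  # [0.007, 0.07]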
Example #7
    def __init__(self, args):
        self.args = args
        self.vs = vs(args.nice)

        #Dataloader
        kwargs = {"num_workers": args.workers, 'pin_memory': True}
        if self.args.dataset == 'bdd':
            _, _, self.test_loader, self.nclass = make_data_loader(
                args, **kwargs)
        else:  #self.args.dataset == 'nice':
            self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        #else:
        #	raise NotImplementedError

        ### Load models
        #backs = ["resnet", "resnet152"]
        backs = ["resnet", "ibn", "resnet152"]
        check = './ckpt'
        checks = ["herbrand.pth.tar", "ign85.12.pth.tar", "r152_85.20.pth.tar"]
        self.models = []
        self.M = len(backs)
        # define models
        for i in range(self.M):
            model = DeepLab(num_classes=self.nclass,
                            backbone=backs[i],
                            output_stride=16,
                            Norm=gn,
                            freeze_bn=False)
            self.models.append(model)
            self.models[i] = torch.nn.DataParallel(
                self.models[i], device_ids=self.args.gpu_ids)
            patch_replication_callback(self.models[i])
            self.models[i] = self.models[i].cuda()
        # load checkpoints
        for i in range(self.M):
            resume = os.path.join(check, checks[i])
            if not os.path.isfile(resume):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(resume))
            checkpoint = torch.load(resume)
            dicts = checkpoint['state_dict']
            model_dict = {}
            state_dict = self.models[i].module.state_dict()
            for k, v in dicts.items():
                if k in state_dict:
                    model_dict[k] = v
            state_dict.update(model_dict)
            self.models[i].module.load_state_dict(state_dict)
            print("{} loaded successfully".format(checks[i]))
Example #8
    def __init__(self, args):
        self.args = args
        if self.args.dataset == 'cityscapes':
            self.nclass = 19

        self.model = AutoDeeplab(num_classes=self.nclass,
                                 num_layers=12,
                                 filter_multiplier=self.args.filter_multiplier,
                                 block_multiplier=args.block_multiplier,
                                 step=args.step)
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

        self.decoder = Decoder(self.model.alphas, self.model.bottom_betas,
                               self.model.betas8, self.model.betas16,
                               self.model.top_betas, args.block_multiplier,
                               args.step)
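
The clean_module branch above exists because nn.DataParallel prefixes every state-dict key with 'module.'; a checkpoint saved from a wrapped model will not load into an unwrapped one until the prefix is stripped. A standalone sketch:

from collections import OrderedDict

import torch.nn as nn

net = nn.Linear(4, 2)
# simulate a checkpoint saved from a DataParallel-wrapped model
wrapped_sd = OrderedDict(('module.' + k, v) for k, v in net.state_dict().items())

new_state_dict = OrderedDict()
for k, v in wrapped_sd.items():
    new_state_dict[k[len('module.'):]] = v  # drop the 'module.' prefix
net.load_state_dict(new_state_dict)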
Example #9
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        val_set = pascal.VOCSegmentation(args, split='val')
        self.nclass = val_set.NUM_CLASSES
        self.val_loader = DataLoader(val_set,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     **kwargs)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            print('device_ids', self.args.gpu_ids)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
Example #10
    def __init__(self, args):
        self.args = args
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone)
        # model = DeepLab(num_classes=self.nclass,
        #                 backbone=args.backbone,
        #                 output_stride=args.out_stride,
        #                 sync_bn=args.sync_bn,
        #                 freeze_bn=args.freeze_bn)
        self.model = model
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            # if not args.ft:
            #     self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
Example #11
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        # kwargs = {'num_workers': 1, 'pin_memory': True}
        self.test_loader, self.nclass = get_test_data_loader(args)
        #_, _, self.test_loader, self.nclass = make_data_loader(self.args)
        # Define network
        if args.model == 'deeplab':
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
        elif args.model == 'unet':
            model = UNet(n_classes=self.nclass, n_channels=3)

        self.model = model

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
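
One caveat with the if/elif model selection above: an unrecognized args.model leaves `model` undefined and fails later with a NameError. A factory-dict variant (with hypothetical stand-in constructors) fails loudly at the point of selection instead:

import torch.nn as nn

# stand-in constructors; the real ones would be DeepLab and UNet
MODEL_FACTORIES = {
    'deeplab': lambda nclass: nn.Conv2d(3, nclass, kernel_size=1),
    'unet': lambda nclass: nn.Conv2d(3, nclass, kernel_size=3, padding=1),
}

def build_model(name, nclass):
    try:
        return MODEL_FACTORIES[name](nclass)
    except KeyError:
        raise ValueError('unknown model: %r' % name)

model = build_model('unet', 6)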
Example #12
    def __init__(self, args):
        self.args = args

        # Define network
        model = DeepLab(num_classes=32,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        #         self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(32)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        time_start = time.time()
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #13
    def initialize_model(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # Define Dataloader
        kwargs = {'num_workers': self.args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            self.args, **kwargs)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.args.backbone,
                        output_stride=self.args.out_stride,
                        sync_bn=self.args.sync_bn,
                        freeze_bn=self.args.freeze_bn)
        self.model = model
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if self.args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        if not os.path.isfile(self.args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args.resume))
        checkpoint = torch.load(self.args.resume)
        self.args.start_epoch = checkpoint['epoch']
        if self.args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
        self.evaluator.reset()
Example #14
def get_model(nclass, args):

    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    # Using cuda
    if args.cuda:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        patch_replication_callback(model)
        model = model.cuda()

    checkpoint = torch.load(args.resume)
    if args.cuda:
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.resume, checkpoint['epoch']))

    return model
Example #15
    def __init__(self, args):
        self.args = args

        # Define network
        model = DeepLab(num_classes=args.num_classes,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        self.model = model
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        _, self.valid_loader = make_data_loader(args, **kwargs)
        self.pred_remap = args.pred_remap
        self.gt_remap = args.gt_remap

        # Define Evaluator
        self.evaluator = Evaluator(args.eval_num_classes)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
Example #16
    def __init__(self, cfgfile):
        self.args = parse_cfg(cfgfile)

        self.nclass = int(self.args['nclass'])

        model = DeepLab(num_classes=self.nclass,
                        backbone=self.args['backbone'],
                        output_stride=int(self.args['out_stride']),
                        sync_bn=bool(self.args['sync_bn']),
                        freeze_bn=bool(self.args['freeze_bn']))

        weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=True).build_loss(mode=self.args['loss_type'])

        self.model = model
        self.evaluator = Evaluator(self.nclass)

        # Using cuda

        self.model = self.model.cuda()
        self.model = torch.nn.DataParallel(self.model, device_ids=[0])
        patch_replication_callback(self.model)
        self.resume = self.args['resume']

        # Resuming checkpoint
        if self.resume is not None:
            if not os.path.isfile(self.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    self.resume))
            checkpoint = torch.load(self.resume)

            self.model.module.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                self.resume, checkpoint['epoch']))
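
One thing to watch in the config parsing above: if parse_cfg returns string values, bool('False') is True, since bool() of any non-empty string is truthy. A string-aware coercion sketch (cfg_bool is a hypothetical helper, not part of the example's code):

def cfg_bool(value):
    """Coerce config values to bool without the bool('False') pitfall."""
    if isinstance(value, str):
        return value.strip().lower() in ('1', 'true', 'yes')
    return bool(value)

print(bool('False'), cfg_bool('False'))  # True False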
Example #17
                            freeze_bn=args.freeze_bn,
                            depth=args.depth)

if args.model == 'unet':
    student_model = Unet(args, num_classes=args.nclass, depth=args.depth)
    teacher_model = Unet(args, num_classes=args.nclass, depth=args.depth)

new_params = student_model.state_dict().copy()
student_model.load_state_dict(new_params)
teacher_model.load_state_dict(new_params)

# Using cuda

student_model = torch.nn.DataParallel(student_model)
teacher_model = torch.nn.DataParallel(teacher_model)
patch_replication_callback(student_model)
patch_replication_callback(teacher_model)
student_model = student_model.cuda()
teacher_model = teacher_model.cuda()

for param in teacher_model.parameters():
    param.requires_grad = False

if args.model == 'deeplabv3' or args.model == 'unet':
    # train_params = [{'params': model.get_10x_lr_params(), 'lr': args.lr},
    # ]

    # Define Optimizer
    # optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),lr=args.lr, momentum=args.momentum,
    #                             weight_decay=args.weight_decay, nesterov=args.nesterov)
    optimizer = torch.optim.Adam(student_model.parameters(),
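
The teacher above has all gradients disabled, which matches the mean-teacher setup where the teacher is updated not by backpropagation but as an exponential moving average of the student. The update step is not shown in this snippet; a hedged sketch of the usual form, assuming identically structured models:

import torch

@torch.no_grad()
def ema_update(teacher, student, alpha=0.99):
    # teacher <- alpha * teacher + (1 - alpha) * student, parameter-wise
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(alpha).add_(s, alpha=1 - alpha)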
Example #18
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out-stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset',
                        type=str,
                        default='dfc19dsm',
                        choices=['pascal', 'coco', 'cityscapes', 'dfc19dsm'],
                        help='dataset name (default: dfc19dsm)')
    parser.add_argument('--use-sbd',
                        action='store_true',
                        default=False,
                        help='whether to use SBD dataset (default: False)')
    parser.add_argument('--workers',
                        type=int,
                        default=1,
                        metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base-size',
                        type=int,
                        default=1024,
                        help='base image size')
    parser.add_argument('--crop-size',
                        type=int,
                        default=1024,
                        help='crop image size')
    parser.add_argument('--sync-bn',
                        type=bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=True,
        help='whether to freeze bn parameters (default: True)')
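    # Editor's note (not in the original): argparse's type=bool treats any
    # non-empty string as True, so '--sync-bn False' still parses as True;
    # action='store_true' / 'store_false' flags avoid this pitfall.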

    parser.add_argument('--batch-size',
                        type=int,
                        default=2,
                        metavar='N',
                        help='input batch size for \
                                training (default: 2)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for \
                                testing (default: 4)')
    parser.add_argument('--crf',
                        type=bool,
                        default=False,
                        help='crf post-processing')

    # cuda, seed and logging
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    # checking point
    parser.add_argument('--resume',
                        type=str,
                        default='/data/PreTrainedModel/DSM/3446.pth.tar',
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname',
                        type=str,
                        default=None,
                        help='set the checkpoint name')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.batch_size is None:
        args.batch_size = 2

    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size

    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    test_loader, nclass = make_test_loader(args, **kwargs)

    # Define network
    model = DeepLab(num_classes=1,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    # Using cuda
    model = torch.nn.DataParallel(model)
    patch_replication_callback(model)
    model = model.cuda()
    output_dir = '/data/yonghao.xu/DFC2019/track1/Result/Split-Val/'
    crfoutput_dir = '/data/yonghao.xu/DFC2019/track1/Result/Split-Valcrf/'
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    if not os.path.exists(crfoutput_dir):
        os.mkdir(crfoutput_dir)
    # Resuming checkpoint

    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)

        if args.cuda:
            model.module.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    model.eval()

    tbar = tqdm(test_loader, desc='\r')

    #%%
    for i, sample in enumerate(tbar):
        image, name = sample
        if args.cuda:
            image = image.cuda()
        with torch.no_grad():
            output = model(image)

        pred = output.data.cpu().numpy()
        image = image.cpu().numpy()

        if not args.crf:
            for j in range(len(name)):
                DSMoutName = name[j].replace(IMG_FILE_STR, DEPTH_FILE_STR)
                DSMpred = pred[j].squeeze() - BIAS
                tifffile.imsave(os.path.join(output_dir, DSMoutName),
                                DSMpred,
                                compress=6)
                #tifffile.imsave(os.path.join(crfoutput_dir, DSMoutName), DSMpred, compress=6)

        else:
            for j in range(len(name)):

                #image = CNNMap
                im = image[j].transpose((1, 2, 0))

                softmax = pred[j].reshape((1, -1))

                # The input should be the negative of the logarithm of probability values
                # Look up the definition of the softmax_to_unary for more information
                unary = softmax_to_unary(softmax)

                # The inputs should be C-contiguous -- we are using the Cython wrapper
                unary = np.ascontiguousarray(unary)

                d = dcrf.DenseCRF(im.shape[0] * im.shape[1], 1)

                d.setUnaryEnergy(unary)

                # This potential penalizes small pieces of segmentation that are
                # spatially isolated -- enforces more spatially consistent segmentations
                feats = create_pairwise_gaussian(sdims=(3, 3),
                                                 shape=im.shape[:2])

                d.addPairwiseEnergy(feats,
                                    compat=3,
                                    kernel=dcrf.DIAG_KERNEL,
                                    normalization=dcrf.NORMALIZE_SYMMETRIC)

                # This creates the color-dependent features --
                # because the segmentation that we get from CNN are too coarse
                # and we can use local color features to refine them
                feats = create_pairwise_bilateral(sdims=(80, 80),
                                                  schan=[13],
                                                  img=im,
                                                  chdim=2)

                d.addPairwiseEnergy(feats,
                                    compat=1,
                                    kernel=dcrf.DIAG_KERNEL,
                                    normalization=dcrf.NORMALIZE_SYMMETRIC)
                Q = d.inference(3)

                res = np.squeeze(Q).reshape((im.shape[0], im.shape[1]))

                crfoutName = name[j].replace(IMG_FILE_STR, DEPTH_FILE_STR)
                crfpred = res - BIAS
                tifffile.imsave(os.path.join(crfoutput_dir, crfoutName),
                                crfpred,
                                compress=6)
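
As the comments in the CRF branch say, the unary term is the negative log of the per-pixel class probabilities, flattened to (n_labels, n_pixels) and made C-contiguous for the Cython wrapper. A numpy-only sketch of just that transformation:

import numpy as np

probs = np.array([[0.9, 0.1],
                  [0.4, 0.6]])               # 2 pixels x 2 classes
unary = -np.log(probs.T).astype(np.float32)  # shape (n_labels, n_pixels)
unary = np.ascontiguousarray(unary)          # DenseCRF wants C-order memory
print(unary.shape)                           # (2, 2)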
Example #19
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        #self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        self.train_loader1, self.train_loader2, self.val_loader, self.nclass = make_data_loader(args, **kwargs)
        
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader1, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion, crop_size=self.args.crop_size)
        optimizer = torch.optim.SGD(
                model.parameters(),
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay
            )
        self.model, self.optimizer = model, optimizer

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader1))

        self.architect = Architect (self.model, args)
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
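
The class-balanced-weights block above follows a cache-on-first-use pattern: compute the per-class weights once, save them as a .npy next to the dataset, and reload them on later runs. A self-contained sketch (the file name and the uniform stand-in weights are illustrative only):

import os

import numpy as np
import torch

weights_path = 'dataset_classes_weights.npy'  # hypothetical cache location
if os.path.isfile(weights_path):
    weight = np.load(weights_path)
else:
    weight = np.ones(21)  # stand-in for calculate_weigths_labels(...)
    np.save(weights_path, weight)
weight = torch.from_numpy(weight.astype(np.float32))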
Example #20
    def __init__(self, args):
        self.args = args
        self.vs = Vs(args.dataset)

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            print("Please check the norm.")
            exit()

        # Define network
        if self.args.model == "deeplabv3+":
            model = DeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        elif self.args.model == "deeplabv3":
            model = DeepLabv3(
                Norm=args.norm,
                backbone=args.backbone,
                output_stride=args.out_stride,
                num_classes=self.nclass,
                freeze_bn=args.freeze_bn,
            )
        elif self.args.model == "fpn":
            model = FPN(args=args, num_classes=self.nclass)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
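
The norm selection in this example maps a string to a normalization-layer constructor (gn, bn, syncbn, abn are repository aliases). A plain-torch sketch of the same dispatch, with stand-in layers:

import torch.nn as nn

def make_norm(kind):
    if kind == 'gn':
        return lambda ch: nn.GroupNorm(32, ch)
    if kind == 'bn':
        return nn.BatchNorm2d
    if kind == 'syncbn':
        return nn.SyncBatchNorm  # torch's built-in synchronized BN
    raise ValueError('Please check the norm: %r' % kind)

norm = make_norm('gn')
print(norm(64))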
Example #21
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        #                        output_stride=args.out_stride,
        #                        sync_bn=args.sync_bn,
        #                        freeze_bn=args.freeze_bn)
        self.decoder = Decoder(self.nclass, 'autodeeplab', args, False)
        # TODO: look into these
        # TODO: also look into different param groups as done in deeplab below
        #        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        #
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))  #TODO: use min_lr ?

        # TODO: figure out if len(self.train_loader) should be divided by two (in other modules as well)
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #22
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = True if (APEX_AVAILABLE and args.use_amp) else False
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                raise NotImplementedError
                #if so, which trainloader to use?
                # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)
        # TODO: figure out if len(self.train_loader) should be divided by two (in other modules as well)
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2'
                                           or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module,
                                  torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape,
                                           dtype=module.running_var.dtype,
                                           device=module.running_var.device),
                                requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape,
                                            dtype=module.running_var.dtype,
                                            device=module.running_var.device),
                                requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer,
                         self.architect_optimizer] = amp.initialize(
                             self.model,
                             [self.optimizer, self.architect_optimizer],
                             opt_level=self.opt_level,
                             keep_batchnorm_fp32=keep_batchnorm_fp32,
                             loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        #checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                # self.model.load_state_dict(new_state_dict)
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if torch.cuda.device_count() > 1 or args.load_parallel:
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(),
                                    checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(),
                                    checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
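
The amp branch above patches affine-less BatchNorm layers by giving them frozen weight and bias tensors so mixed-precision forward passes do not fail. A standalone sketch of just the patch:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8, affine=False)  # weight and bias are None
if bn.weight is None:
    bn.weight = nn.Parameter(torch.ones_like(bn.running_var),
                             requires_grad=False)
if bn.bias is None:
    bn.bias = nn.Parameter(torch.zeros_like(bn.running_var),
                           requires_grad=False)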
Example #23
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
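
The trainer above builds its SGD optimizer from two parameter groups: the pretrained backbone at the base learning rate and the decoder head at ten times that rate. The same pattern in isolation, with hypothetical stand-in modules (not the example's actual network):

import torch

# hypothetical modules standing in for the backbone / decoder head
backbone = torch.nn.Linear(8, 8)
head = torch.nn.Linear(8, 2)

base_lr = 0.007
param_groups = [
    {'params': backbone.parameters(), 'lr': base_lr},       # pretrained features: small lr
    {'params': head.parameters(), 'lr': base_lr * 10},      # freshly initialized head: larger lr
]
optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4)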
Example #24
0
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(f_params+y_params, lr= args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params, lr= args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(f_params, lr= args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params+y_params, lr= args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params, lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params+y_params, lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join('dataloders', 'datasets',
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader, self.nclass, classes_weights_path, self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        self.ca_loss = ''

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
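
The trainer above wires up four optimizers over shared parameter lists: a task optimizer for features plus segmentation head, a discriminator optimizer, an inverted optimizer that updates only the features against the domain loss, and a consistency optimizer. A toy sketch of the first three steps; the module shapes, loss wiring, and domain labels are our assumptions, not code from the example:

import torch
import torch.nn as nn

# toy stand-ins for the feature extractor (backbone + ASPP),
# segmentation head, and domain classifier
feat = nn.Linear(8, 8)
seg_head = nn.Linear(8, 2)
dom_clf = nn.Linear(8, 2)

task_opt = torch.optim.SGD(list(feat.parameters()) + list(seg_head.parameters()), lr=0.01)
d_opt = torch.optim.SGD(dom_clf.parameters(), lr=0.01)
d_inv_opt = torch.optim.SGD(feat.parameters(), lr=0.01)

ce = nn.CrossEntropyLoss()
x_src, y_src = torch.randn(4, 8), torch.randint(0, 2, (4,))
x_tgt = torch.randn(4, 8)

# 1) supervised task step on labelled source data
loss = ce(seg_head(feat(x_src)), y_src)
task_opt.zero_grad(); loss.backward(); task_opt.step()

# 2) train the domain classifier to separate source (label 0) from target (label 1)
d_loss = ce(dom_clf(feat(x_src).detach()), torch.zeros(4, dtype=torch.long)) \
       + ce(dom_clf(feat(x_tgt).detach()), torch.ones(4, dtype=torch.long))
d_opt.zero_grad(); d_loss.backward(); d_opt.step()

# 3) inverted step: update the features so target samples look like source
inv_loss = ce(dom_clf(feat(x_tgt)), torch.zeros(4, dtype=torch.long))
d_inv_opt.zero_grad(); inv_loss.backward(); d_inv_opt.step()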
Example #25
0
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        cell_path_d = os.path.join(args.saved_arch_path, 'genotype_device.npy')
        cell_path_c = os.path.join(args.saved_arch_path, 'genotype_cloud.npy')
        network_path_space = os.path.join(args.saved_arch_path, 'network_path_space.npy')

        new_cell_arch_d = np.load(cell_path_d)
        new_cell_arch_c = np.load(cell_path_c)
        new_network_arch = np.load(network_path_space)
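        # NOTE: the architecture loaded from network_path_space.npy is
        # immediately overridden by the hard-coded list on the next line.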
        new_network_arch = [1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2]
        # Define network
        model = new_cloud_Model(network_arch= new_network_arch,
                         cell_arch_d = new_cell_arch_d,
                         cell_arch_c = new_cell_arch_c,
                         num_classes=self.nclass,
                         device_num_layers=6)

        # TODO: look into these
        # TODO: also look into different param groups as done in deeplab below
#        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
#                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
#
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator_device = Evaluator(self.nclass)
        self.evaluator_cloud = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader)) #TODO: use min_lr ?

        # Using cuda
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

            # fix for BN layers without affine params when using opt_level 'O1'
            # on older PyTorch (note: this string comparison of versions is crude)
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            # print(keep_batchnorm_fp32)
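            # NOTE: self.architect_optimizer is presumably defined elsewhere in
            # the original trainer (architecture search); this excerpt omits it.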
            self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
                self.model, [self.optimizer, self.architect_optimizer], opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) >1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a DataParallel module object we
            # have to strip the 'module.' prefix before loading
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
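
All of the resume blocks in the examples above expect a checkpoint dict with 'epoch', 'state_dict', 'optimizer', and 'best_pred' keys. A minimal sketch of the saving side that would produce such a file; the helper name, the default path, and the unwrap step are our assumptions:

import torch

def save_checkpoint(model, optimizer, epoch, best_pred, path='checkpoint.pth.tar'):
    # unwrap DataParallel if present so keys are saved without the 'module.' prefix
    state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    torch.save({
        'epoch': epoch,
        'state_dict': state_dict,
        'optimizer': optimizer.state_dict(),
        'best_pred': best_pred,
    }, path)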
Example #26
0
import os
import time  # used below; missing from the original imports

import cv2
import torch  # used below; missing from the original imports
import torch.nn.functional as F
from dataloaders import utils
from res18test.resnet_dilated import *
from doc.deeplab_resnet import *

#import pydensecrf.densecrf as dcrf
model = DeepLab(num_classes=2,
                backbone='mobilenet',
                output_stride=32,
                sync_bn=False,
                freeze_bn=False)

model = torch.nn.DataParallel(model, device_ids=[0, 1])
patch_replication_callback(model)
model = model.cuda()
checkpoint = torch.load(
    "/home/xupeihan/deeplab/run/vocdetection/mb_finAL/experiment_2/checkpoint.pth.tar"
)
model.module.load_state_dict(checkpoint['state_dict'])
#model.load_state_dict(checkpoint['state_dict'])

model.eval()
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

time1 = time.time()

PATH = '/home/xupeihan/deeplab/img/'
#PATH = '/mnt/disk2/xupeihan/seg_code/res-101-1-25mix/img/'
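
Example #26 above sets ImageNet normalization statistics and iterates over an image folder, but the per-image preprocessing is cut off in this excerpt. A hedged sketch of the pipeline that typically follows (the helper name and the exact transform order are assumptions):

import cv2
import numpy as np
import torch

def segment_image(model, path, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Assumed per-image pipeline: BGR->RGB, scale to [0,1], normalize, NCHW, argmax."""
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = (img - np.array(mean)) / np.array(std)
    x = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
    x = x.unsqueeze(0).float().cuda()
    with torch.no_grad():
        out = model(x)  # (1, num_classes, H, W)
    return out.argmax(dim=1).squeeze(0).cpu().numpy()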
Example #27
0
    def __init__(self, args):
        self.args = args


        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())


        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            '''if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])'''
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else: 
            print('No Resuming Checkpoint Given')
            raise NotImplementedError

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
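
Example #27 splits DeepLabV3+ into four separately wrapped modules but does not show the forward pass in this excerpt. In reference implementations of this layout the modules typically compose as below; the exact signatures (the backbone returning a low-level skip tensor, the decoder taking two inputs) are an assumption:

import torch.nn.functional as F

def forward(backbone_model, assp_model, y_model, x):
    # backbone returns high-level features plus a low-level skip tensor
    feat, low_level_feat = backbone_model(x)
    feat = assp_model(feat)
    out = y_model(feat, low_level_feat)  # decoder fuses the skip connection
    # upsample logits back to the input resolution
    return F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=True)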
Example #28
0
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion

        self.criterion = SegmentationLosses(cuda=args.cuda)
        self.model, self.optimizer = model, optimizer
        self.contexts = TemporalContexts(history_len=5)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning or in validation/test mode
        if args.ft or args.mode == "val" or args.mode == "test":
            args.start_epoch = 0
            self.best_pred = 0.0
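
The Evaluator used throughout these trainers accumulates a confusion matrix via add_batch and derives metrics from it. A reduced sketch of that idea (our own condensed version, not the project's class):

import numpy as np

def confusion_matrix(gt, pred, nclass):
    # bincount over joint (gt, pred) indices gives an nclass x nclass matrix
    mask = (gt >= 0) & (gt < nclass)
    return np.bincount(nclass * gt[mask].astype(int) + pred[mask],
                       minlength=nclass ** 2).reshape(nclass, nclass)

def miou(cm):
    inter = np.diag(cm)
    union = cm.sum(axis=1) + cm.sum(axis=0) - inter
    # classes absent from both gt and pred contribute IoU 0 here
    return np.mean(inter / np.maximum(union, 1))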
Example #29
0
    def __init__(self, args):
        self.args = args

        # Generate .npy file for dataloader
        self.img_process(args)

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = getattr(modeling, args.model_name)(pretrained=args.pretrained)

        # Define Optimizer
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        # train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Criterion
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
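
Example #29 above selects the network by name with getattr(modeling, args.model_name). The same dynamic-lookup pattern in isolation; torchvision stands in for the example's own modeling package, and the model name is illustrative:

import torchvision.models as modeling  # stand-in for the example's `modeling` package

model_name = 'resnet18'  # e.g. taken from args.model_name
model_cls = getattr(modeling, model_name)
model = model_cls(pretrained=False)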
Example #30
0
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print(self.nclass, args.backbone, args.out_stride, args.sync_bn,
              args.freeze_bn)
        #2 resnet 16 False False

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
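
The trainer above is the only one here that passes map_location='cpu' to torch.load; that remaps GPU-saved storages so a checkpoint can be opened on a CPU-only machine before the weights are moved back to the right device:

import torch

# remap all storages to CPU regardless of the device they were saved from
checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')

# alternatively, remap directly onto the current device:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# checkpoint = torch.load('checkpoint.pth.tar', map_location=device)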
Example #31
0
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Predicting")

    parser.add_argument('--model',
                        type=str,
                        default='deeplabv3',
                        choices=['deeplabv3', 'unet'],
                        help='model name (default: deeplabv3)')
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=[
                            'resnet', 'resnext', 'senet', 'senext', 'cbamnet',
                            'cbamnext'
                        ],
                        help='backbone name (default: resnet)')
    parser.add_argument('--depth',
                        type=int,
                        default=None,
                        help='choose which model depth (default: 50)')
    parser.add_argument('--out-stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset',
                        type=str,
                        default=None,
                        choices=['val', 'tes'],
                        help='dataset split to predict on (default: None)')
    parser.add_argument('--nclass',
                        type=int,
                        default=8,
                        help='number of classes')
    parser.add_argument('--bsz', type=int, default=256, help='base image size')
    parser.add_argument('--csz', type=int, default=224, help='crop image size')
    parser.add_argument('--rsz',
                        type=int,
                        default=256,
                        help='resample image size')
    parser.add_argument('--oly-s1',
                        action='store_true',
                        default=False,
                        help='only use s1 data')
    parser.add_argument('--oly-s2',
                        action='store_true',
                        default=False,
                        help='only use s2 data')
    parser.add_argument('--scale',
                        type=str,
                        default='std',
                        choices=['std', 'norm'],
                        help='how to scale in preprocessing')

    parser.add_argument('--rgb',
                        action='store_true',
                        default=False,
                        help='data augmentation')
    parser.add_argument('--denoise',
                        action='store_true',
                        default=False,
                        help='data augmentation')
    parser.add_argument('--dehaze',
                        action='store_true',
                        default=False,
                        help='data augmentation')
    parser.add_argument('--rule',
                        type=str,
                        default=None,
                        choices=['dw_jiu', 'dw_new'],
                        help='label filter')

    parser.add_argument('--sync-bn',
                        type=bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=True,
        help='whether to freeze bn parameters (default: True)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=2,
                        metavar='N',
                        help='input batch size for training (default: 2)')
    parser.add_argument('--pre-batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for prediction (default: 4)')
    parser.add_argument('--crf',
                        action='store_true',
                        default=False,
                        help='crf postprocessing')
    parser.add_argument('--mode',
                        type=str,
                        default='none',
                        choices=['soft', 'hard', 'none'],
                        help='voting method')

    # cuda, seed and logging
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    # checking point
    parser.add_argument('--resume',
                        type=str,
                        default='/data/PreTrainedModel/DSM/3446.pth.tar',
                        help='put the path to resuming file if needed')
    # output dir
    parser.add_argument('--dir',
                        type=str,
                        default=None,
                        help='folder of prediction of test dataset')
    parser.add_argument('--export',
                        type=str,
                        default=None,
                        choices=['image', 'prob'],
                        help='export data category')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.batch_size is None:
        args.batch_size = 1

    if args.pre_batch_size is None:
        args.pre_batch_size = args.batch_size

    if args.oly_s1 and not args.oly_s2:
        print('Using only s1 SAR data!')
    elif not args.oly_s1 and args.oly_s2 and not args.rgb:
        print('Using only s2 MSI data!')
    elif not args.oly_s1 and args.oly_s2 and args.rgb:
        print('Using only s2 RGB data!')
    elif not args.oly_s1 and not args.oly_s2 and not args.rgb:
        print('Using s1 and s2 data at the same time!')
    elif not args.oly_s1 and not args.oly_s2 and args.rgb:
        print('Using s1 and s2 rgb data at the same time!')
    else:
        raise NotImplementedError

    outputdir = args.dir + 'output/'
    visdir = args.dir + 'vis/'
    probdir = args.dir + 'prob/'

    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    if not os.path.exists(visdir):
        os.makedirs(visdir)

    if not os.path.exists(probdir):
        os.makedirs(probdir)

    ########################################## model ###########################################

    # Define network
    if args.model == 'deeplabv3':
        model = DeepLab(args,
                        num_classes=args.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn,
                        depth=args.depth)

    elif args.model == 'unet':
        model = Unet(args, num_classes=args.nclass, depth=args.depth)

    # Using cuda
    model = torch.nn.DataParallel(model)
    patch_replication_callback(model)
    model = model.cuda()

    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)

        if args.cuda:
            model.module.load_state_dict(checkpoint['teacher_state_dict'])
        else:
            model.load_state_dict(checkpoint['teacher_state_dict'])

        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    model.eval()

    ###################################### prepare data ########################################

    if args.dataset == 'val':
        base_dir = '/data/PublicData/DF2020/val/'
        s1_pre, s2_pre, lc_pre = load_valdata(base_dir)
        print('{} val images for prediction'.format(s1_pre.shape[0]))

    if args.dataset == 'tes':

        base_dir = '/data/PublicData/DF2020/test_track1/'
        s1_pre, s2_pre, lc_pre = load_tesdata(base_dir)

        print('{} tes images for prediction'.format(s1_pre.shape[0]))

    pre_dataset = DF2020(args, s1_pre, s2_pre, lc_pre, split='pre')

    pre_loader = DataLoader(dataset=pre_dataset,
                            batch_size=args.pre_batch_size,
                            shuffle=False)

    tbar = tqdm(pre_loader, desc='\r')

    evaluator = Evaluator(10)

    pre_prob = np.zeros([s1_pre.shape[0], args.nclass, 256, 256])

    for i, (x1, x2, y, index) in enumerate(tbar):

        x1_ori = np.array(x1).transpose(0, 2, 3, 1)  # N,H,W,C
        x2_ori = np.array(x2).transpose(0, 2, 3, 1)

        # x1
        data1_1 = x1_ori.copy()
        data1_2 = np.rot90(data1_1, 1, (1, 2)).copy()
        data1_3 = np.rot90(data1_1, 2, (1, 2)).copy()
        data1_4 = np.rot90(data1_1, 3, (1, 2)).copy()

        x1_hinv = x1_ori[:, :, ::-1, :].copy()

        data1_5 = x1_hinv.copy()
        data1_6 = np.rot90(x1_hinv, 1, (1, 2)).copy()
        data1_7 = np.rot90(x1_hinv, 2, (1, 2)).copy()
        data1_8 = np.rot90(x1_hinv, 3, (1, 2)).copy()

        data1_9 = np.concatenate([
            data1_1, data1_2, data1_3, data1_4, data1_5, data1_6, data1_7,
            data1_8
        ],
                                 axis=0)
        # x2
        data2_1 = x2_ori.copy()
        data2_2 = np.rot90(data2_1, 1, (1, 2)).copy()
        data2_3 = np.rot90(data2_1, 2, (1, 2)).copy()
        data2_4 = np.rot90(data2_1, 3, (1, 2)).copy()

        x2_hinv = x2_ori[:, :, ::-1, :].copy()

        data2_5 = x2_hinv.copy()
        data2_6 = np.rot90(x2_hinv, 1, (1, 2)).copy()
        data2_7 = np.rot90(x2_hinv, 2, (1, 2)).copy()
        data2_8 = np.rot90(x2_hinv, 3, (1, 2)).copy()

        data2_9 = np.concatenate([
            data2_1, data2_2, data2_3, data2_4, data2_5, data2_6, data2_7,
            data2_8
        ],
                                 axis=0)
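        # The eight variants built above are the dihedral group of the square:
        # the four rotations of the original plus the four rotations of its
        # horizontal mirror, i.e. classic 8-fold test-time augmentation.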

        if args.cuda:
            data1_9 = data1_9.transpose(0, 3, 1, 2)  # (8,C,H,W)
            data2_9 = data2_9.transpose(0, 3, 1, 2)

            data1_9 = torch.from_numpy(data1_9).cuda()
            data2_9 = torch.from_numpy(data2_9).cuda()

        with torch.no_grad():

            _, _, output = model(data1_9, data2_9)

        output = F.softmax(output, dim=1)

        pred = output.data.cpu().numpy()

        if args.mode == 'hard':

            if not args.oly_s1 and args.crf:
                label_temp = np.zeros([8, pred.shape[-2], pred.shape[-1]])
                for j in range(pred.shape[0]):
                    _, label_temp[j] = CRF(
                        pred[j], data2_9[j, [3, 2, 1], :, :].cpu().numpy())
            else:
                label_temp = np.argmax(pred, axis=1)  # (8,H,W)

            ############# inverse transform ################

            label1 = label_temp[0].copy()  # (H,W)
            label2 = np.rot90(label_temp[1], -1, (0, 1)).copy()
            label3 = np.rot90(label_temp[2], -2, (0, 1)).copy()
            label4 = np.rot90(label_temp[3], -3, (0, 1)).copy()

            label5 = label_temp[4][:, ::-1].copy()
            label6 = np.rot90(label_temp[5], -1, (0, 1))[:, ::-1].copy()
            label7 = np.rot90(label_temp[6], -2, (0, 1))[:, ::-1].copy()
            label8 = np.rot90(label_temp[7], -3, (0, 1))[:, ::-1].copy()

            ############## voting ################

            label_mat_3 = np.concatenate([
                np.expand_dims(label1, 0),
                np.expand_dims(label2, 0),
                np.expand_dims(label3, 0),
                np.expand_dims(label4, 0),
                np.expand_dims(label5, 0),
                np.expand_dims(label6, 0),
                np.expand_dims(label7, 0),
                np.expand_dims(label8, 0)
            ],
                                         axis=0)  # (8,H,W)

            label = np.zeros((256 * 256, ))
            label_mat_3 = np.reshape(label_mat_3, [8, 256 * 256])
            for m in range(256 * 256):
                temp = label_mat_3[:, m]
                label[m] = np.argmax(np.bincount(temp))

            label = np.reshape(label, [256, 256])

        elif args.mode == 'soft':

            label_temp = pred

            ############# inverse transform ################

            label1 = label_temp[0].copy()  # (nclass, H, W)
            label2 = np.rot90(label_temp[1], -1, (1, 2)).copy()
            label3 = np.rot90(label_temp[2], -2, (1, 2)).copy()
            label4 = np.rot90(label_temp[3], -3, (1, 2)).copy()

            label5 = label_temp[4][:, :, ::-1].copy()
            label6 = np.rot90(label_temp[5], -1, (1, 2))[:, :, ::-1].copy()
            label7 = np.rot90(label_temp[6], -2, (1, 2))[:, :, ::-1].copy()
            label8 = np.rot90(label_temp[7], -3, (1, 2))[:, :, ::-1].copy()

            ############## voting ################

            label_mat = label1 + label2 + label3 + label4 + label5 + label6 + label7 + label8  # (nclass, H, W)

            if not args.oly_s1 and args.crf:
                if args.rgb:
                    _, label = CRF(label_mat / 8, x2.squeeze().numpy())
                else:
                    _, label = CRF(label_mat / 8,
                                   x2[:, [3, 2, 1], :, :].squeeze().numpy())
            else:
                label = np.argmax(label_mat, 0)

        elif args.mode == 'none':

            label_mat = pred[0, :, :, :]  # (nclass, H, W)

            if not args.oly_s1 and args.crf:
                if args.rgb:
                    _, label = CRF(label_mat / 8, x2.squeeze().numpy())
                else:
                    _, label = CRF(label_mat,
                                   x2[:, [3, 2, 1], :, :].squeeze().numpy())
            else:
                label = np.argmax(label_mat, 0)

        if args.export == 'image':

            h, w = label.shape

            label = label.reshape(-1)

            label = list(map(lambda x: net2templab[x], label))

            label = np.array(label)  # list->array

            label = label.reshape(-1, h, w)

            im = np.uint8(label + 1)

            filename = os.path.basename(lc_pre[index])
            former = filename.split('lc')[0]
            latter = filename.split('lc')[1]

            f = base_dir + 's2_0/' + former + 's2' + latter

            with rasterio.open(f) as patch:
                x2_img = patch.read(list(range(1, 14)))

            #print(x2_img.shape)

            NDWI, NDVI, NBI, MSI, BSI = Cal_INDEX(x2_img)

            y_tmp = y[0, 0, :, :].cpu().numpy().copy()

            # im[np.where(NDWI>0)]=10
            # label[np.where(NDWI>0)]=9

            im_tmp = im.copy()
            y_pre_tmp = label.copy()

            # grassland

            im[np.where((NDWI < 0) & (NDVI > 0.4) & (NDVI < 0.6)
                        & ((y_tmp == 2) | (y_tmp == 3) | (im_tmp == 5))
                        & (np.sum(NDWI > 0) < 2000))] = 4

            label[np.where((NDWI < 0) & (NDVI > 0.4) & (NDVI < 0.6)
                           & ((y_tmp == 2) | (y_tmp == 3) | (y_pre_tmp == 4))
                           & (np.sum(NDWI > 0) < 2000))] = 3

            # wetland

            im[np.where((NDVI > 0.6) & (NDVI < 0.75)
                        & ((y_tmp == 4) | (im_tmp == 4))
                        & (np.sum(NDWI > 0) > 4000))] = 5

            label[np.where((NDVI > 0.6) & (NDVI < 0.75)
                           & ((y_tmp == 4) | (y_pre_tmp == 3))
                           & (np.sum(NDWI > 0) > 4000))] = 4

            # forest

            im[np.where((NDWI < 0) & (NDVI > 0.75)
                        & ((y_tmp == 2) | (y_tmp == 0) | (im_tmp == 4)
                           | (im_tmp == 5)))] = 1
            label[np.where((NDWI < 0) & (NDVI > 0.75)
                           & ((y_tmp == 2) | (y_tmp == 0) | (y_pre_tmp == 3)
                              | (y_pre_tmp == 4)))] = 0

            im[np.where((im_tmp == 5) & (np.sum(NDWI > 0) < 1000))] = 1
            label[np.where((y_pre_tmp == 4) & (np.sum(NDWI > 0) < 1000))] = 0

            # cropland

            im[np.where((NDWI < 0) & (NDVI < 0.4) & (NDVI > 0.2) & (MSI > 1)
                        & (MSI < 1.5) & ((y_tmp == 5) | (y_tmp == 2)))] = 6
            label[np.where((NDWI < 0) & (NDVI < 0.4) & (NDVI > 0.2) & (MSI > 1)
                           & (MSI < 1.5) & ((y_tmp == 5) | (y_tmp == 2)))] = 5

            # urban

            im[np.where((NDWI < 0) & (NDVI < 0.2) & (NDVI > 0) & (BSI > -0.4)
                        & (y_tmp == 6))] = 7
            label[np.where((NDWI < 0) & (NDVI < 0.2) & (NDVI > 0)
                           & (BSI > -0.4) & (y_tmp == 6))] = 6

            # barren

            im[np.where((NDWI < 0) & (NBI > 750) & (NDVI < 0.4) & (NDVI > 0)
                        & (y_tmp != 5) & (y_tmp != 6) & (im_tmp != 6)
                        & (im_tmp != 7) & ((im_tmp == 4) | (im_tmp == 2)))] = 9
            label[np.where((NDWI < 0) & (NBI > 750) & (NDVI < 0.4) & (NDVI > 0)
                           & (y_tmp != 5) & (y_tmp != 6) & (y_pre_tmp != 5)
                           & (y_pre_tmp != 6)
                           & ((y_pre_tmp == 3) | (y_pre_tmp == 1)))] = 8

            # shrubland

            im[np.where((NDWI < 0) & (y_tmp == 1) & (im_tmp == 2))] = 2
            label[np.where((NDWI < 0) & (y_tmp == 1) & (y_pre_tmp == 1))] = 1

            imsave(outputdir + former + 'dfc' + latter, im)
            im_rgb = Image.fromarray(
                np.uint8(
                    DrawResult(label.reshape(-1), x2.shape[-2], x2.shape[-1])))
            im_rgb.save(visdir + former + 'dfc' + latter[:-4] + '_vis.png')

        elif args.export == 'prob':

            if args.mode == 'soft':

                label = label[np.newaxis, :, :]

                pred = label_mat / 8.0

            b, h, w = label.shape

            label = label.reshape(-1)

            label = list(map(lambda x: net2templab[x], label))

            label = np.array(label)  # list->array

            label = label.reshape(-1, h, w)

            pre_prob[index, :, :, :] = pred

        else:
            raise NotImplementedError

        if args.export == 'image':

            target = y[:, 0, :, :].cpu().numpy()  # batch_size * 256 * 256
            # Add batch sample into the evaluator
            evaluator.add_batch(target, label[np.newaxis, :, :])

        elif args.export == 'prob':

            target = y[:, 0, :, :].cpu().numpy()  # batch_size * 256 * 256
            # Add batch sample into the evaluator
            evaluator.add_batch(target, label)

        else:
            raise NotImplementedError

    AA = evaluator.pre_Pixel_Accuracy_Class()

    print('AVERAGE ACCURACY of {} DATASET: {}'.format(str(args.dataset), AA))

    print(
        'ACCURACY IN EACH CLASSES:',
        np.diag(evaluator.confusion_matrix) /
        evaluator.confusion_matrix.sum(axis=1))

    if args.export == 'image':

        print('IMAGE EXPORTED finished!')

    elif args.export == 'prob':

        pre_prob = pre_prob.astype('float32')

        np.save(
            probdir + 'DFC2020_tes_' + str(args.model) + '_' +
            str(args.backbone) + '_' + str(args.batch_size) + '_.npy',
            pre_prob)

        print('PROBABILITY EXPORTED finished!')

    print('Prediction finished!')
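
The hard and soft voting branches in Example #31 undo each augmentation before combining predictions. A compact sketch of the soft-voting variant for one image's stack of probability maps; this is our own condensed version of the same idea, assuming the view ordering used above:

import numpy as np

def soft_vote(probs):
    """probs: (8, C, H, W), ordered as [rot0, rot90, rot180, rot270,
    flip, flip+rot90, flip+rot180, flip+rot270]."""
    undone = []
    for k in range(4):  # undo the rotations
        undone.append(np.rot90(probs[k], -k, axes=(1, 2)))
    for k in range(4):  # undo the rotation, then the horizontal flip
        undone.append(np.rot90(probs[4 + k], -k, axes=(1, 2))[:, :, ::-1])
    mean_prob = np.stack(undone).mean(axis=0)  # (C, H, W)
    return mean_prob.argmax(axis=0)            # (H, W) label map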