Example #1
    def __init__(self, args):
        self.device = torch.device(args.device)
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            num_workers=args.workers)
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, args.test_batch_size, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=args.workers,
                pin_memory=True)

        # create network
        self.net = LEDNet(trainset.NUM_CLASS)

        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                self.net.load_state_dict(
                    torch.load(args.resume, map_location=self.device))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if args.ohem:
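            # keep at least 1/16 of the pixels in each batch as hard examples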
            min_kept = args.batch_size * args.crop_size**2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                         min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss()

        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters,
                                      power=0.9)

        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)
        self.args = args
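
The IterationBasedBatchSampler above makes a single pass over self.train_loader span the whole schedule of args.max_iter batches, so the WarmupPolyLR scheduler is stepped once per batch rather than once per epoch. A minimal, hypothetical sketch of a train method built on this constructor (the (images, targets) batch layout and the criterion's call signature are assumptions, not shown in the example):

    def train(self):
        # one pass over the loader == args.max_iter batches (iteration-based sampler)
        self.net.train()
        for images, targets in self.train_loader:
            images, targets = images.to(self.device), targets.to(self.device)
            loss = self.criterion(self.net(images), targets)  # assumed signature
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()  # poly decay advances every iteration
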
Example #2
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    data_kwargs = {
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'transform': input_transform
    }

    val_dataset = get_segmentation_dataset(args.dataset,
                                           split=args.split,
                                           mode=args.mode,
                                           **data_kwargs)
    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler,
                                      batch_size=args.batch_size,
                                      drop_last=False)
    val_data = data.DataLoader(val_dataset,
                               batch_sampler=batch_sampler,
                               num_workers=args.num_workers)
    metric = SegmentationMetric(val_dataset.num_class)

    metric = validate(model, val_data, metric, device)
    ptutil.synchronize()
    pixAcc, mIoU = ptutil.accumulate_metric(metric)
    if ptutil.is_main_process():
        print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
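
make_data_sampler is a repo helper that is used here but never shown. A plausible reconstruction, assuming it only switches between a distributed sampler and a plain one (the actual implementation may differ):

    from torch.utils import data

    def make_data_sampler(dataset, shuffle, distributed):
        # distributed evaluation shards the dataset across processes
        if distributed:
            return data.DistributedSampler(dataset, shuffle=shuffle)
        if shuffle:
            return data.RandomSampler(dataset)
        return data.SequentialSampler(dataset)
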
Example #3
    def __init__(self, args):
        self.device = torch.device(args.device)
        self.save_prefix = '_'.join((args.model, args.backbone, args.dataset))
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            num_workers=args.workers)
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, args.test_batch_size, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=args.workers,
                pin_memory=True)

        # create network
        if args.model_zoo is not None:
            self.net = get_model(args.model_zoo, pretrained=True)
        else:
            kwargs = {'oc': args.oc} if args.model == 'ocnet' else {}
            self.net = get_segmentation_model(model=args.model,
                                              dataset=args.dataset,
                                              backbone=args.backbone,
                                              aux=args.aux,
                                              dilated=args.dilated,
                                              jpu=args.jpu,
                                              crop_size=args.crop_size,
                                              **kwargs)
        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                self.net.load_state_dict(
                    torch.load(args.resume, map_location=self.device))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if args.ohem:
            min_kept = args.batch_size * args.crop_size**2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                         min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(
                args.aux, aux_weight=args.aux_weight)

        # optimizer and lr scheduling
        params_list = [{
            'params': self.net.base1.parameters(),
            'lr': args.lr
        }, {
            'params': self.net.base2.parameters(),
            'lr': args.lr
        }, {
            'params': self.net.base3.parameters(),
            'lr': args.lr
        }]
        if hasattr(self.net, 'others'):
            for name in self.net.others:
                params_list.append({
                    'params':
                    getattr(self.net, name).parameters(),
                    'lr':
                    args.lr * 10
                })
        if hasattr(self.net, 'JPU'):
            params_list.append({
                'params': self.net.JPU.parameters(),
                'lr': args.lr * 10
            })
        self.optimizer = optim.SGD(params_list,
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters,
                                      power=0.9)

        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)
        self.args = args
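
params_list gives the backbone groups (base1..base3) the base learning rate and the head/decoder groups 10x that, and WarmupPolyLR then rescales every group each iteration. A hypothetical sketch of the poly rule implied by power=0.9 (the exact warmup shape inside the repo's WarmupPolyLR may differ):

    def poly_lr(base_lr, cur_iter, max_iter, warmup_iters, warmup_factor, power=0.9):
        if cur_iter < warmup_iters:
            # linear ramp from warmup_factor * base_lr up to base_lr
            alpha = cur_iter / warmup_iters
            return base_lr * (warmup_factor * (1 - alpha) + alpha)
        return base_lr * (1 - cur_iter / max_iter) ** power
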
Example #4
    def __init__(self, config, args, logger):
        self.DISTRIBUTED, self.DEVICE = ptutil.init_environment(config, args)
        self.LR = config.TRAIN.LR * len(config.GPUS)
        self.device = torch.device(self.DEVICE)
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        
        data_kwargs = {
            'transform': input_transform if config.DATASET.IMG_TRANSFORM else None,
            'base_size': config.DATASET.BASE_SIZE,
            'crop_size': config.DATASET.CROP_SIZE
        }
        trainset = get_segmentation_dataset(
            config.DATASET.NAME, split=config.TRAIN.TRAIN_SPLIT, mode='train', **data_kwargs)
        self.per_iter = len(trainset) // (len(config.GPUS) * config.TRAIN.BATCH_SIZE)
        self.max_iter = config.TRAIN.EPOCHS * self.per_iter
        if self.DISTRIBUTED:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, config.TRAIN.BATCH_SIZE, True)
        train_sampler = IterationBasedBatchSampler(train_sampler, num_iterations=self.max_iter)
        self.train_loader = data.DataLoader(trainset, batch_sampler=train_sampler, pin_memory=config.DATASET.PIN_MEMORY,
                                            num_workers=config.DATASET.WORKERS)
        if not config.TRAIN.SKIP_EVAL or 0 < config.TRAIN.EVAL_EPOCHS < config.TRAIN.EPOCHS:
            valset = get_segmentation_dataset(config.DATASET.NAME, split='val', mode='val', **data_kwargs)
            val_sampler = make_data_sampler(valset, False, self.DISTRIBUTED)
            val_batch_sampler = data.sampler.BatchSampler(val_sampler, config.TEST.TEST_BATCH_SIZE, False)
            self.valid_loader = data.DataLoader(valset, batch_sampler=val_batch_sampler,
                                                num_workers=config.DATASET.WORKERS, pin_memory=config.DATASET.PIN_MEMORY)
        # create network
        self.net = get_segmentation_model(config.MODEL.NAME, nclass=trainset.NUM_CLASS).cuda()
        if self.DISTRIBUTED:
            if config.TRAIN.MIXED_PRECISION:
                self.net = apex.parallel.convert_syncbn_model(self.net)
            else:
                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        if config.TRAIN.RESUME != '':
            self.net = ptutil.model_resume(self.net, config.TRAIN.RESUME,
                                           logger).to(self.device)
        # self.net.to(self.device)
        assert config.TRAIN.SEG_LOSS in (
            'focalloss2d', 'mixsoftmaxcrossentropyohemloss',
            'mixsoftmaxcrossentropy'), 'cannot support {}'.format(config.TRAIN.SEG_LOSS)
        if config.TRAIN.SEG_LOSS == 'focalloss2d':
            self.criterion = get_loss(config.TRAIN.SEG_LOSS, gamma=2., use_weight=False,
                                      size_average=True, ignore_index=config.DATASET.IGNORE_INDEX)
        elif config.TRAIN.SEG_LOSS == 'mixsoftmaxcrossentropyohemloss':
            min_kept = int(config.TRAIN.BATCH_SIZE // len(config.GPUS) *
                           config.DATASET.CROP_SIZE ** 2 // 16)
            self.criterion = get_loss(config.TRAIN.SEG_LOSS, min_kept=min_kept,
                                      ignore_index=config.DATASET.IGNORE_INDEX).to(self.device)
        else:
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,
                                      ignore_index=config.DATASET.IGNORE_INDEX)

        self.optimizer = optim.SGD(self.net.parameters(), lr=self.LR, momentum=config.TRAIN.MOMENTUM,
                                   weight_decay=config.TRAIN.WEIGHT_DECAY)
        self.scheduler = WarmupPolyLR(self.optimizer, T_max=self.max_iter, warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                      warmup_iters=config.TRAIN.WARMUP_ITERS, power=0.9)
        # self.net.apply(fix_bn)
        if config.TRAIN.MIXED_PRECISION:
            self.dtype = torch.half
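            # amp.initialize returns the model and optimizer patched for mixed-precision execution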
            self.net, self.optimizer = amp.initialize(
                self.net, self.optimizer, opt_level=config.TRAIN.MIXED_OPT_LEVEL)
        else:
            self.dtype = torch.float
        if self.DISTRIBUTED:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net, device_ids=[args.local_rank], output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.NUM_CLASS)
        self.config = config
        self.logger = logger
        ptutil.mkdir(self.config.TRAIN.SAVE_DIR)
        model_path = os.path.join(
            self.config.TRAIN.SAVE_DIR, "{}_{}_{}_init.pth".format(
                config.MODEL.NAME, config.TRAIN.SEG_LOSS, config.DATASET.NAME))
        ptutil.save_model(self.net, model_path, self.logger)
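
With MIXED_PRECISION enabled, the backward pass has to go through apex's amp.scale_loss so the fp16 gradients are scaled and then unscaled before the optimizer step. A hypothetical sketch of one training iteration using the members set up above (the batch layout and the criterion's call signature are assumptions):

        images = images.to(self.device, dtype=self.dtype)
        loss = self.criterion(self.net(images), targets.to(self.device))
        self.optimizer.zero_grad()
        if self.config.TRAIN.MIXED_PRECISION:
            # scale the loss so half-precision gradients do not underflow
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()
        self.scheduler.step()
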
Example #5
    def __init__(self, config, args, logger):
        self.DISTRIBUTED, self.DEVICE = ptutil.init_environment(config, args)
        self.LR = config.TRAIN.LR * len(config.GPUS)  # scale by num gpus
        self.GENERATOR_LR = config.TRAIN.GENERATOR_LR * len(config.GPUS)
        self.device = torch.device(self.DEVICE)
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        if config.DATASET.IMG_TRANSFORM:
            data_kwargs = {
                "transform": input_transform,
                "base_size": config.DATASET.BASE_SIZE,
                "crop_size": config.DATASET.CROP_SIZE
            }
        else:
            data_kwargs = {
                "transform": None,
                "base_size": config.DATASET.BASE_SIZE,
                "crop_size": config.DATASET.CROP_SIZE
            }
        # target dataset
        targetdataset = get_segmentation_dataset('targetdataset',
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        trainset = get_segmentation_dataset(config.DATASET.NAME,
                                            split=config.TRAIN.TRAIN_SPLIT,
                                            mode='train',
                                            **data_kwargs)
        self.per_iter = len(trainset) // (len(config.GPUS) *
                                          config.TRAIN.BATCH_SIZE)
        targetset_per_iter = len(targetdataset) // (len(config.GPUS) *
                                                    config.TRAIN.BATCH_SIZE)
        targetset_max_iter = config.TRAIN.EPOCHS * targetset_per_iter
        self.max_iter = config.TRAIN.EPOCHS * self.per_iter
        if self.DISTRIBUTED:
            sampler = data.DistributedSampler(trainset)
            target_sampler = data.DistributedSampler(targetdataset)
        else:
            sampler = data.RandomSampler(trainset)
            target_sampler = data.RandomSampler(targetdataset)
        train_sampler = data.sampler.BatchSampler(sampler,
                                                  config.TRAIN.BATCH_SIZE,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=self.max_iter)
        self.train_loader = data.DataLoader(
            trainset,
            batch_sampler=train_sampler,
            pin_memory=config.DATASET.PIN_MEMORY,
            num_workers=config.DATASET.WORKERS)
        target_train_sampler = data.sampler.BatchSampler(
            target_sampler, config.TRAIN.BATCH_SIZE, True)
        target_train_sampler = IterationBasedBatchSampler(
            target_train_sampler, num_iterations=targetset_max_iter)
        self.target_loader = data.DataLoader(
            targetdataset,
            batch_sampler=target_train_sampler,
            pin_memory=False,
            num_workers=config.DATASET.WORKERS)
        self.target_trainloader_iter = enumerate(self.target_loader)
        if not config.TRAIN.SKIP_EVAL or 0 < config.TRAIN.EVAL_EPOCH < config.TRAIN.EPOCHS:
            valset = get_segmentation_dataset(config.DATASET.NAME,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, self.DISTRIBUTED)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, config.TEST.TEST_BATCH_SIZE, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=config.DATASET.WORKERS,
                pin_memory=False)

        # create network
        self.seg_net = get_segmentation_model(
            config.MODEL.SEG_NET, nclass=trainset.NUM_CLASS).cuda()
        self.feature_extracted = vgg19(pretrained=True)
        self.generator = get_segmentation_model(config.MODEL.TARGET_GENERATOR)

        if self.DISTRIBUTED:
            if config.TRAIN.MIXED_PRECISION:
                self.seg_net = apex.parallel.convert_syncbn_model(self.seg_net)
                self.feature_extracted = apex.parallel.convert_syncbn_model(
                    self.feature_extracted)
                self.generator = apex.parallel.convert_syncbn_model(
                    self.generator)
            else:
                self.seg_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.seg_net)
                self.feature_extracted = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.feature_extracted)
                self.generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.generator)

        # resume checkpoint if needed
        if config.TRAIN.RESUME != '':
            logger.info('loading {} parameter ...'.format(
                config.MODEL.SEG_NET))
            self.seg_net = ptutil.model_resume(self.seg_net,
                                               config.TRAIN.RESUME,
                                               logger).to(self.device)
        if config.TRAIN.RESUME_GENERATOR != '':
            logger.info('loading {} parameter ...'.format(
                config.MODEL.TARGET_GENERATOR))
            self.generator = ptutil.model_resume(self.generator,
                                                 config.TRAIN.RESUME_GENERATOR,
                                                 logger).to(self.device)

        self.feature_extracted.to(self.device)
        # create criterion
        assert config.TRAIN.SEG_LOSS in (
            'focalloss2d', 'mixsoftmaxcrossentropyohemloss',
            'mixsoftmaxcrossentropy'), 'cannot support {}'.format(
                config.TRAIN.SEG_LOSS)
        if config.TRAIN.SEG_LOSS == 'focalloss2d':
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,
                                      gamma=2.,
                                      use_weight=False,
                                      size_average=True,
                                      ignore_index=config.DATASET.IGNORE_INDEX)
        elif config.TRAIN.SEG_LOSS == 'mixsoftmaxcrossentropyohemloss':
            min_kept = int(config.TRAIN.BATCH_SIZE // len(config.GPUS) *
                           config.DATASET.CROP_SIZE**2 // 16)
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,
                                      min_kept=min_kept,
                                      ignore_index=-1).to(self.device)
        else:
            self.criterion = get_loss(config.TRAIN.SEG_LOSS, ignore_index=-1)

        self.gen_criterion = get_loss('mseloss')
        self.kl_criterion = get_loss('criterionkldivergence')
        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.seg_net.parameters(),
                                   lr=self.LR,
                                   momentum=config.TRAIN.MOMENTUM,
                                   weight_decay=config.TRAIN.WEIGHT_DECAY)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=self.max_iter,
                                      warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                      warmup_iters=config.TRAIN.WARMUP_ITERS,
                                      power=0.9)
        self.gen_optimizer = optim.SGD(self.generator.parameters(),
                                       lr=self.GENERATOR_LR,
                                       momentum=config.TRAIN.MOMENTUM,
                                       weight_decay=config.TRAIN.WEIGHT_DECAY)
        self.gen_scheduler = WarmupPolyLR(
            self.gen_optimizer,
            T_max=self.max_iter,
            warmup_factor=config.TRAIN.WARMUP_FACTOR,
            warmup_iters=config.TRAIN.WARMUP_ITERS,
            power=0.9)

        if config.TRAIN.MIXED_PRECISION:
            [self.seg_net, self.generator
             ], [self.optimizer, self.gen_optimizer
                 ] = amp.initialize([self.seg_net, self.generator],
                                    [self.optimizer, self.gen_optimizer],
                                    opt_level=config.TRAIN.MIXED_OPT_LEVEL)
            self.dtype = torch.half
        else:
            self.dtype = torch.float
        if self.DISTRIBUTED:
            self.seg_net = torch.nn.parallel.DistributedDataParallel(
                self.seg_net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
            self.generator = torch.nn.parallel.DistributedDataParallel(
                self.generator,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
            self.feature_extracted = torch.nn.parallel.DistributedDataParallel(
                self.feature_extracted,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.NUM_CLASS)
        self.config = config
        self.logger = logger
        self.seg_dir = os.path.join(self.config.TRAIN.SAVE_DIR, 'seg')
        ptutil.mkdir(self.seg_dir)
        self.generator_dir = os.path.join(self.config.TRAIN.SAVE_DIR,
                                          'generator')
        ptutil.mkdir(self.generator_dir)
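
self.target_trainloader_iter stores an enumerate over the target-domain loader so target batches can be pulled from inside the source-domain training loop. A hypothetical sketch of how it is typically consumed, rebuilding the iterator when the target loader is exhausted (variable names are assumptions):

        try:
            _, target_batch = next(self.target_trainloader_iter)
        except StopIteration:
            # the target loader ran out first; restart it and keep training
            self.target_trainloader_iter = enumerate(self.target_loader)
            _, target_batch = next(self.target_trainloader_iter)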