def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
    ])

    # dataset and dataloader
    data_kwargs = {
        'transform': input_transform,
        'base_size': cfg.TRAIN.BASE_SIZE,
        'crop_size': cfg.TRAIN.CROP_SIZE
    }
    train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                             split='train',
                                             mode='train',
                                             **data_kwargs)
    val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                           split='val',
                                           mode=cfg.DATASET.MODE,
                                           **data_kwargs)
    self.iters_per_epoch = len(train_dataset) // (args.num_gpus * cfg.TRAIN.BATCH_SIZE)
    self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

    train_sampler = make_data_sampler(train_dataset,
                                      shuffle=True,
                                      distributed=args.distributed)
    train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                  cfg.TRAIN.BATCH_SIZE,
                                                  self.max_iters,
                                                  drop_last=True)
    val_sampler = make_data_sampler(val_dataset, False, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                cfg.TEST.BATCH_SIZE,
                                                drop_last=False)
    self.train_loader = data.DataLoader(dataset=train_dataset,
                                        batch_sampler=train_batch_sampler,
                                        num_workers=cfg.DATASET.WORKERS,
                                        pin_memory=True)
    self.val_loader = data.DataLoader(dataset=val_dataset,
                                      batch_sampler=val_batch_sampler,
                                      num_workers=cfg.DATASET.WORKERS,
                                      pin_memory=True)

    # create network
    self.model = get_segmentation_model().to(self.device)

    # print params and flops
    if get_rank() == 0:
        try:
            show_flops_params(self.model, args.device)
        except Exception as e:
            logging.warning('get flops and params error: {}'.format(e))

    if cfg.MODEL.BN_TYPE not in ['BN']:
        logging.info('Batch norm type is {}, convert_sync_batchnorm is not effective'
                     .format(cfg.MODEL.BN_TYPE))
    elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        logging.info('SyncBatchNorm is effective!')
    else:
        logging.info('Not using SyncBatchNorm.')

    # create criterion
    self.criterion = get_segmentation_loss(cfg.MODEL.MODEL_NAME,
                                           use_ohem=cfg.SOLVER.OHEM,
                                           aux=cfg.SOLVER.AUX,
                                           aux_weight=cfg.SOLVER.AUX_WEIGHT,
                                           ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)

    # optimizer; the model comprises just the encoder and decoder (head and aux layer)
    self.optimizer = get_optimizer(self.model)

    # lr scheduling
    self.lr_scheduler = get_scheduler(self.optimizer,
                                      max_iters=self.max_iters,
                                      iters_per_epoch=self.iters_per_epoch)

    # resume checkpoint if needed
    self.start_epoch = 0
    if args.resume and os.path.isfile(args.resume):
        name, ext = os.path.splitext(args.resume)
        assert ext in ('.pkl', '.pth'), 'Sorry, only .pth and .pkl files are supported.'
        logging.info('Resuming training, loading {}...'.format(args.resume))
        resume_state = torch.load(args.resume)
        self.model.load_state_dict(resume_state['state_dict'])
        self.start_epoch = resume_state['epoch']
        logging.info('resume train from epoch: {}'.format(self.start_epoch))
        if resume_state['optimizer'] is not None and resume_state['lr_scheduler'] is not None:
            logging.info('resume optimizer and lr scheduler from resume state...')
            self.optimizer.load_state_dict(resume_state['optimizer'])
            self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

    if args.distributed:
        self.model = nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # evaluation metrics
    self.metric = SegmentationMetric(train_dataset.num_class, args.distributed)
    self.best_pred = 0.0
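# A minimal sketch, not part of the original trainer: a hypothetical helper
# showing the checkpoint layout the resume branch above reads back with
# torch.load(). The keys 'state_dict', 'epoch', 'optimizer', and
# 'lr_scheduler' are the ones it consumes; the method name is illustrative.
def _save_resume_state(self, epoch, path):
    # unwrap DistributedDataParallel so the keys match a bare model on reload
    model = self.model.module if hasattr(self.model, 'module') else self.model
    state = {
        'state_dict': model.state_dict(),                # -> load_state_dict
        'epoch': epoch,                                  # -> self.start_epoch
        'optimizer': self.optimizer.state_dict(),        # may be None to skip
        'lr_scheduler': self.lr_scheduler.state_dict(),  # may be None to skip
    }
    torch.save(state, path)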
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)

    # tensorboard writers: one for the clean val set, one for the foggy (noisy) set
    self.prefix = "ADE_cce_alpha={}".format(cfg.TRAIN.ALPHA)
    self.writer = SummaryWriter(log_dir=f"iccv_tensorboard/{self.prefix}")
    self.writer_noisy = SummaryWriter(log_dir=f"iccv_tensorboard/{self.prefix}-foggy")

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
    ])

    # dataset and dataloader; train and val use different crop sizes
    train_data_kwargs = {
        'transform': input_transform,
        'base_size': cfg.TRAIN.BASE_SIZE,
        'crop_size': cfg.TRAIN.CROP_SIZE
    }
    val_data_kwargs = {
        'transform': input_transform,
        'base_size': cfg.TRAIN.BASE_SIZE,
        'crop_size': cfg.TEST.CROP_SIZE
    }
    train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                             split='train',
                                             mode='train',
                                             **train_data_kwargs)
    val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                           split='val',
                                           mode="val",
                                           **val_data_kwargs)
    self.classes = val_dataset.classes
    self.iters_per_epoch = len(train_dataset) // (args.num_gpus * cfg.TRAIN.BATCH_SIZE)
    self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

    # calibration metrics: expected and class-wise calibration error
    self.ece_evaluator = ECELoss(n_classes=len(self.classes))
    self.cce_evaluator = CCELoss(n_classes=len(self.classes))

    train_sampler = make_data_sampler(train_dataset,
                                      shuffle=True,
                                      distributed=args.distributed)
    train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                  cfg.TRAIN.BATCH_SIZE,
                                                  self.max_iters,
                                                  drop_last=True)
    val_sampler = make_data_sampler(val_dataset, False, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                cfg.TEST.BATCH_SIZE,
                                                drop_last=False)
    self.train_loader = data.DataLoader(dataset=train_dataset,
                                        batch_sampler=train_batch_sampler,
                                        num_workers=cfg.DATASET.WORKERS,
                                        pin_memory=True)
    self.val_loader = data.DataLoader(dataset=val_dataset,
                                      batch_sampler=val_batch_sampler,
                                      num_workers=cfg.DATASET.WORKERS,
                                      pin_memory=True)

    # DEFINE data for noisy (currently disabled)
    # val_dataset_noisy = get_segmentation_dataset(cfg.DATASET.NOISY_NAME,
    #                                              split='val', mode="val",
    #                                              **train_data_kwargs)
    # self.val_loader_noisy = data.DataLoader(dataset=val_dataset_noisy,
    #                                         batch_sampler=val_batch_sampler,
    #                                         num_workers=cfg.DATASET.WORKERS,
    #                                         pin_memory=True)

    # create network: build an mmsegmentation model from its config and
    # pretrained weights instead of the in-repo get_segmentation_model()
    mmseg_config_file = cfg.MODEL.MMSEG_CONFIG
    mmseg_pretrained = cfg.TRAIN.PRETRAINED_MODEL_PATH
    self.model = init_segmentor(mmseg_config_file, mmseg_pretrained)
    self.model.to(self.device)

    # freeze the backbone; only the decode/auxiliary heads are trained
    for params in self.model.backbone.parameters():
        params.requires_grad = False

    # print params and flops
    if get_rank() == 0:
        try:
            show_flops_params(copy.deepcopy(self.model), args.device)
        except Exception as e:
            logging.warning('get flops and params error: {}'.format(e))

    if cfg.MODEL.BN_TYPE not in ['BN']:
        logging.info('Batch norm type is {}, convert_sync_batchnorm is not effective'
                     .format(cfg.MODEL.BN_TYPE))
    elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        logging.info('SyncBatchNorm is effective!')
    else:
        logging.info('Not using SyncBatchNorm.')

    # create criterion (extends the stock loss with the class count and the
    # calibration weight alpha)
    self.criterion = get_segmentation_loss(cfg.MODEL.MODEL_NAME,
                                           use_ohem=cfg.SOLVER.OHEM,
                                           aux=cfg.SOLVER.AUX,
                                           aux_weight=cfg.SOLVER.AUX_WEIGHT,
                                           ignore_index=cfg.DATASET.IGNORE_INDEX,
                                           n_classes=len(train_dataset.classes),
                                           alpha=cfg.TRAIN.ALPHA).to(self.device)

    # optimizer; the model comprises just the encoder and decoder (head and aux layer)
    self.optimizer = get_optimizer_mmseg(self.model)

    # lr scheduling
    self.lr_scheduler = get_scheduler(self.optimizer,
                                      max_iters=self.max_iters,
                                      iters_per_epoch=self.iters_per_epoch)

    # resume checkpoint if needed
    self.start_epoch = 0
    if args.resume and os.path.isfile(args.resume):
        name, ext = os.path.splitext(args.resume)
        assert ext in ('.pkl', '.pth'), 'Sorry, only .pth and .pkl files are supported.'
        logging.info('Resuming training, loading {}...'.format(args.resume))
        resume_state = torch.load(args.resume)
        self.model.load_state_dict(resume_state['state_dict'])
        self.start_epoch = resume_state['epoch']
        logging.info('resume train from epoch: {}'.format(self.start_epoch))
        if resume_state['optimizer'] is not None and resume_state['lr_scheduler'] is not None:
            logging.info('resume optimizer and lr scheduler from resume state...')
            self.optimizer.load_state_dict(resume_state['optimizer'])
            self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

    # evaluation metrics
    self.metric = SegmentationMetric(train_dataset.num_class, args.distributed)
    self.best_pred_miou = 0.0
    self.best_pred_cces = 1e15
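# A minimal sketch, assuming get_optimizer_mmseg simply drops the frozen
# backbone parameters (requires_grad=False above) so that only the decode and
# auxiliary heads are updated. The cfg.SOLVER.* hyperparameter names are
# assumptions for illustration, not the repo's actual config keys.
# def get_optimizer_mmseg(model):
#     head_params = [p for p in model.parameters() if p.requires_grad]
#     return torch.optim.SGD(head_params,
#                            lr=cfg.SOLVER.BASE_LR,
#                            momentum=cfg.SOLVER.MOMENTUM,
#                            weight_decay=cfg.SOLVER.WEIGHT_DECAY)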