コード例 #1
0
    def __init__(self, args):
        """Build the clean and noisy validation loaders, the model and the metric.

        Args:
            args: Parsed command-line namespace. This method reads
                ``args.device``, ``args.distributed`` and ``args.local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader for the clean validation split
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = val_dataset.classes

        # dataset and dataloader for the noisy validation split
        val_dataset_noisy = get_segmentation_dataset(cfg.DATASET.NOISY_NAME,
                                                     split='val',
                                                     mode='testval',
                                                     transform=input_transform)
        # The noisy loader gets samplers built from its own dataset. Reusing
        # the clean split's batch sampler would hand out indices sized for
        # val_dataset, which is only valid when both datasets have exactly
        # the same length.
        val_sampler_noisy = make_data_sampler(val_dataset_noisy, False,
                                              args.distributed)
        val_batch_sampler_noisy = make_batch_data_sampler(
            val_sampler_noisy,
            images_per_batch=cfg.TEST.BATCH_SIZE,
            drop_last=False)
        self.val_loader_noisy = data.DataLoader(
            dataset=val_dataset_noisy,
            batch_sampler=val_batch_sampler_noisy,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # Optionally override the batch-norm eps inside the encoder.
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)
コード例 #2
0
    def __init__(self, args):
        """Set up the TensorBoard writer, the validation loader and the metric.

        Args:
            args: Parsed command-line namespace; ``args.device`` and
                ``args.distributed`` are read here.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # The run name embeds the hard-coded learning rate and selects the
        # TensorBoard log directory.
        self.lr = 2.5
        self.prefix = f"2_boxes_info_entropy_51_49_alpha=1_lr={self.lr}"
        run_dir = f"cce_toy_entropy_logs/{self.prefix}"
        self.writer = SummaryWriter(log_dir=run_dir)

        # Validation split, served shuffled with the last partial batch dropped.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        loader_options = {
            'shuffle': True,
            'batch_size': cfg.TEST.BATCH_SIZE,
            'drop_last': True,
            'num_workers': cfg.DATASET.WORKERS,
            'pin_memory': True,
        }
        self.val_loader = data.DataLoader(dataset=eval_set, **loader_options)

        self.dataset = eval_set
        self.classes = eval_set.classes
        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)
コード例 #3
0
ファイル: only_crf_eval.py プロジェクト: neelabh17/SegmenTron
    def __init__(self, args):
        """Prepare the validation dataset, the metric and the CrfRnn post-processor.

        Args:
            args: Parsed command-line namespace; ``args.device`` and
                ``args.distributed`` are read here.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # Only the dataset itself is kept; no DataLoader is created here.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        self.dataset = eval_set
        self.classes = eval_set.classes
        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)

        # Post-processor is sized by the number of dataset classes.
        n_classes = len(self.classes)
        self.postprocessor = CrfRnn(n_classes)
コード例 #4
0
    def __init__(self, args):
        """Configure ECE evaluation, the validation loader, the metric and the model.

        Args:
            args: Parsed command-line namespace. This method reads
                ``args.device``, ``args.distributed`` and ``args.local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Hard-coded experiment settings.
        self.n_bins = 15
        self.ece_folder = "eceData"
        self.postfix = "foggy_zurich_conv13"
        self.temp = 1.5
        self.useCRF = True

        # ECE criterion, binned with n_bins.
        ece = metrics.IterativeECELoss()
        ece.make_bins(n_bins=self.n_bins)
        self.ece_criterion = ece

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # Validation split with an unshuffled sampler; the last partial batch
        # is kept.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        sampler = make_data_sampler(eval_set, False, args.distributed)
        batch_sampler = make_batch_data_sampler(
            sampler,
            images_per_batch=cfg.TEST.BATCH_SIZE,
            drop_last=False,
        )
        self.val_loader = data.DataLoader(
            dataset=eval_set,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True,
        )

        self.dataset = eval_set
        self.classes = eval_set.classes
        print(args.distributed)
        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)

        self.model = get_segmentation_model().to(self.device)

        # Optionally override the batch-norm eps inside the encoder.
        encoder_is_walkable = (hasattr(self.model, 'encoder')
                               and hasattr(self.model.encoder, 'named_modules'))
        if encoder_is_walkable and cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(
                self.model.encoder.named_modules(),
                'eps',
                cfg.MODEL.BN_EPS_FOR_ENCODER,
            )

        if args.distributed:
            rank = args.local_rank
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[rank],
                output_device=rank,
                find_unused_parameters=True,
            )

        self.model.to(self.device)
コード例 #5
0
    def __init__(self, args):
        """Set up logging, the validation loader, the metric and the three networks.

        Besides the segmentation model this creates ``self.poolnet`` and
        ``self.fcn`` (an ``FCNs`` wrapping the pool net), both sized by the
        number of dataset classes.

        Args:
            args: Parsed command-line namespace. This method reads
                ``args.device``, ``args.distributed`` and ``args.local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # The run name embeds the hard-coded learning rate and selects the
        # TensorBoard log directory.
        self.lr = 7.5
        self.prefix = f"2_img_cce_only_lr={self.lr}"
        run_dir = f"cce_cityscapes_conv_fcn_logs/{self.prefix}"
        self.writer = SummaryWriter(log_dir=run_dir)

        # Validation split, served shuffled with the last partial batch dropped.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        loader_options = {
            'shuffle': True,
            'batch_size': cfg.TEST.BATCH_SIZE,
            'drop_last': True,
            'num_workers': cfg.DATASET.WORKERS,
            'pin_memory': True,
        }
        self.val_loader = data.DataLoader(dataset=eval_set, **loader_options)

        self.dataset = eval_set
        self.classes = eval_set.classes
        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)

        self.model = get_segmentation_model().to(self.device)

        # Auxiliary networks, both placed on the same device as the model.
        n_classes = len(self.classes)
        self.poolnet = poolNet(n_classes).to(self.device)
        self.fcn = FCNs(self.poolnet, len(self.classes)).to(self.device)

        # Optionally override the batch-norm eps inside the encoder.
        encoder_is_walkable = (hasattr(self.model, 'encoder')
                               and hasattr(self.model.encoder, 'named_modules'))
        if encoder_is_walkable and cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(
                self.model.encoder.named_modules(),
                'eps',
                cfg.MODEL.BN_EPS_FOR_ENCODER,
            )

        if args.distributed:
            rank = args.local_rank
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[rank],
                output_device=rank,
                find_unused_parameters=True,
            )

        self.model.to(self.device)
コード例 #6
0
ファイル: new_crf_eval.py プロジェクト: neelabh17/SegmenTron
    def __init__(self, args):
        """Create the DenseCRF post-processor, ECE criterion, dataset and metric.

        No model and no DataLoader are created in this constructor.

        Args:
            args: Parsed command-line namespace; ``args.device`` and
                ``args.distributed`` are read here.
        """
        # DenseCRF parameters all come from the CRF section of the config.
        crf_options = {
            'iter_max': cfg.CRF.ITER_MAX,
            'pos_xy_std': cfg.CRF.POS_XY_STD,
            'pos_w': cfg.CRF.POS_W,
            'bi_xy_std': cfg.CRF.BI_XY_STD,
            'bi_rgb_std': cfg.CRF.BI_RGB_STD,
            'bi_w': cfg.CRF.BI_W,
        }
        self.postprocessor = DenseCRF(**crf_options)

        self.args = args
        self.device = torch.device(args.device)

        # Hard-coded experiment settings.
        self.n_bins = 15
        self.ece_folder = "eceData"
        self.postfix = "Foggy_DBF_low_DLV3Plus"
        self.temp = 2.3
        self.useCRF = False

        # ECE criterion, binned with n_bins.
        ece = metrics.IterativeECELoss()
        ece.make_bins(n_bins=self.n_bins)
        self.ece_criterion = ece

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # Only the dataset itself is kept.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        self.dataset = eval_set
        self.classes = eval_set.classes

        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)
コード例 #7
0
    def __init__(self, args):
        """Configure ECE evaluation settings, the validation dataset and the metric.

        No model, DataLoader or post-processor is created in this constructor.

        Args:
            args: Parsed command-line namespace; ``args.device`` and
                ``args.distributed`` are read here.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Hard-coded experiment settings.
        self.n_bins = 15
        self.ece_folder = "eceData"
        self.postfix = "Snow_VOC_1"
        self.temp = 1.7
        self.useCRF = False

        # ECE criterion, binned with n_bins.
        ece = metrics.IterativeECELoss()
        ece.make_bins(n_bins=self.n_bins)
        self.ece_criterion = ece

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # Only the dataset itself is kept.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        self.dataset = eval_set
        self.classes = eval_set.classes
        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)
コード例 #8
0
ファイル: eval_mmseg.py プロジェクト: neelabh17/SegmenTron
    def __init__(self, args):
        """Build the validation loader, load an MMSegmentation model and create the metric.

        Args:
            args: Parsed command-line namespace; ``args.device`` and
                ``args.distributed`` are read here.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # Validation split with an unshuffled sampler.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='val',
            mode='testval',
            transform=preprocess,
        )
        sampler = make_data_sampler(eval_set, False, args.distributed)

        # NOTE(review): the original author states the batch size is always 1
        # here; the value is nevertheless taken from cfg.TEST.BATCH_SIZE.
        batch_sampler = make_batch_data_sampler(
            sampler,
            images_per_batch=cfg.TEST.BATCH_SIZE,
            drop_last=False,
        )
        self.val_loader = data.DataLoader(
            dataset=eval_set,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True,
        )
        self.classes = eval_set.classes

        # MMSeg model built from a hard-coded config file and checkpoint path.
        mmseg_config_file = "mmseg-configs/deeplabv3plus_r101-d8_512x512_80k_ade20k.py"
        mmseg_pretrained = "pretrained_weights/deeplabv3plus_r101-d8_512x512_80k_ade20k_20200615_014139-d5730af7.pth"
        segmentor = init_segmentor(mmseg_config_file, mmseg_pretrained)
        self.model = segmentor

        self.model.to(self.device)
        self.metric = SegmentationMetric(eval_set.num_class, args.distributed)
コード例 #9
0
ファイル: train.py プロジェクト: zhanghongyan6553/SegmenTron
    def __init__(self, args):
        """Build data loaders, model, loss, optimizer, LR schedule and metric for training.

        Args:
            args: Parsed command-line namespace. This method reads
                ``args.device``, ``args.num_gpus``, ``args.distributed``,
                ``args.local_rank`` and ``args.resume``.

        Raises:
            AssertionError: If ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode=cfg.DATASET.MODE,
                                               **data_kwargs)
        # Iterations per epoch are counted over the global batch
        # (num_gpus * per-GPU batch size); any remainder is discarded.
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)
        # print params and flops (rank 0 only); failures here are logged and
        # do not abort training setup.
        if get_rank() == 0:
            try:
                show_flops_params(self.model, args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            ext = os.path.splitext(args.resume)[1]
            # Membership test: the previous `ext == '.pkl' or '.pth'` was
            # always true because the non-empty string '.pth' is truthy, so
            # the extension was never actually checked.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_state = torch.load(args.resume)
            self.model.load_state_dict(resume_state['state_dict'])
            self.start_epoch = resume_state['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            # Optimizer and scheduler are restored only when both were saved.
            if resume_state['optimizer'] is not None and resume_state[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_state['optimizer'])
                self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

        # Wrap for distributed training only after the checkpoint is loaded
        # into the bare model.
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)
        self.best_pred = 0.0
コード例 #10
0
    def __init__(self, args):
        """Build the test loader, the model and separate overall/easy/hard metrics.

        Args:
            args: Parsed command-line namespace. This method reads
                ``args.device``, ``args.distributed``, ``args.local_rank``
                and ``args.num_gpus``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Preprocessing: tensor conversion followed by dataset normalisation.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD)
        preprocess = transforms.Compose([to_tensor, normalize])

        # Evaluation runs on the 'test' split, loaded in 'val' mode.
        eval_set = get_segmentation_dataset(
            cfg.DATASET.NAME,
            split='test',
            mode='val',
            transform=preprocess,
            base_size=cfg.TRAIN.BASE_SIZE,
        )

        sampler = make_data_sampler(eval_set,
                                    shuffle=False,
                                    distributed=args.distributed)
        batch_sampler = make_batch_data_sampler(
            sampler,
            images_per_batch=cfg.TEST.BATCH_SIZE,
            drop_last=False,
        )
        self.val_loader = data.DataLoader(
            dataset=eval_set,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True,
        )
        # NOTE(review): the value logged is len() of the loader, not of the
        # dataset; confirm that this is the intended "number of images".
        loader_len = len(self.val_loader)
        logging.info('**** number of images: {}. ****'.format(loader_len))

        self.classes = eval_set.classes

        # create network
        self.model = get_segmentation_model().to(self.device)

        # Optionally override the batch-norm eps inside the encoder.
        has_encoder = hasattr(self.model, 'encoder')
        if has_encoder and cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(
                self.model.encoder.named_modules(),
                'eps',
                cfg.MODEL.BN_EPS_FOR_ENCODER,
            )

        if args.distributed:
            rank = args.local_rank
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[rank],
                output_device=rank,
                find_unused_parameters=True,
            )

        self.model.to(self.device)
        num_gpu = args.num_gpus

        # One independent metric object each for all, easy and hard images.
        for metric_attr in ('metric', 'metric_easy', 'metric_hard'):
            setattr(
                self, metric_attr,
                SegmentationMetric(eval_set.num_class, args.distributed,
                                   num_gpu))

        # Running counts of easy and hard images.
        self.count_easy = 0
        self.count_hard = 0
コード例 #11
0
    def __init__(self, args):
        """Build loaders, an MMSegmentation model, loss, optimizer and metrics for training.

        The model's backbone parameters are frozen (``requires_grad = False``)
        right after the model is created.

        Args:
            args: Parsed command-line namespace. This method reads
                ``args.device``, ``args.num_gpus``, ``args.distributed`` and
                ``args.resume``.

        Raises:
            AssertionError: If ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # TensorBoard writers: one for the regular run, one with a "-foggy"
        # suffix.
        self.prefix = "ADE_cce_alpha={}".format(cfg.TRAIN.ALPHA)
        self.writer = SummaryWriter(log_dir=f"iccv_tensorboard/{self.prefix}")
        self.writer_noisy = SummaryWriter(
            log_dir=f"iccv_tensorboard/{self.prefix}-foggy")

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader; train and val differ only in crop size
        train_data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }
        val_data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TEST.CROP_SIZE
        }
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **train_data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode="val",
                                               **val_data_kwargs)

        self.classes = val_dataset.classes
        # Iterations per epoch are counted over the global batch
        # (num_gpus * per-GPU batch size); any remainder is discarded.
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        self.ece_evaluator = ECELoss(n_classes=len(self.classes))
        self.cce_evaluator = CCELoss(n_classes=len(self.classes))

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # create network: an MMSeg model built from the configured config
        # file and pretrained checkpoint path
        mmseg_config_file = cfg.MODEL.MMSEG_CONFIG
        mmseg_pretrained = cfg.TRAIN.PRETRAINED_MODEL_PATH
        self.model = init_segmentor(mmseg_config_file, mmseg_pretrained)
        self.model.to(self.device)

        # Freeze the backbone.
        for params in self.model.backbone.parameters():
            params.requires_grad = False

        # print params and flops (rank 0 only) on a deep copy of the model;
        # failures here are logged and do not abort training setup.
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX,
            n_classes=len(train_dataset.classes),
            alpha=cfg.TRAIN.ALPHA).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer_mmseg(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            ext = os.path.splitext(args.resume)[1]
            # Membership test: the previous `ext == '.pkl' or '.pth'` was
            # always true because the non-empty string '.pth' is truthy, so
            # the extension was never actually checked.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_state = torch.load(args.resume)
            self.model.load_state_dict(resume_state['state_dict'])
            self.start_epoch = resume_state['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            # Optimizer and scheduler are restored only when both were saved.
            if resume_state['optimizer'] is not None and resume_state[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_state['optimizer'])
                self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

        # evaluation metrics; best mIoU starts low, best CCE starts very high
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)
        self.best_pred_miou = 0.0
        self.best_pred_cces = 1e15