예제 #1
0
class Evaluator(object):
    """Validation runner with temperature scaling and optional DenseCRF
    post-processing.

    Besides the usual pixAcc/mIoU metrics, it accumulates per-pixel
    confidence, predicted class and ground-truth label tensors and pickles
    them to disk for later ECE (calibration) analysis.
    """

    def __init__(self, args):
        # DenseCRF post-processor applied to each image's softmax output.
        self.postprocessor = DenseCRF(
            iter_max=cfg.CRF.ITER_MAX,
            pos_xy_std=cfg.CRF.POS_XY_STD,
            pos_w=cfg.CRF.POS_W,
            bi_xy_std=cfg.CRF.BI_XY_STD,
            bi_rgb_std=cfg.CRF.BI_RGB_STD,
            bi_w=cfg.CRF.BI_W,
        )
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = val_dataset.classes
        # create network
        self.model = get_segmentation_model().to(self.device)

        # Optionally override eps on every BatchNorm layer in the encoder.
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        ``named_modules`` (an iterable of (name, module) pairs)."""
        for m in named_modules:
            if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(m[1], attr, value)

    @staticmethod
    def _get_raw_image(file_location):
        """Load an image as BGR H x W x C uint8 for the CRF post-processor.

        NOTE(review): the BGR mean is subtracted *before* the uint8 cast, so
        values below the mean wrap around; kept as-is to preserve the
        original behaviour.  (The original also did a CHW/HWC transpose
        round-trip around the cast, which is a value-level no-op.)
        """
        raw_image = cv2.imread(file_location,
                               cv2.IMREAD_COLOR).astype(np.float32)
        mean_bgr = np.array([103.53, 116.28, 123.675])
        raw_image -= mean_bgr
        return raw_image.astype(np.uint8)

    def eval(self):
        """Run validation, optionally refine predictions with DenseCRF, and
        dump per-pixel confidence/prediction/label pickles for ECE analysis."""
        self.metric.reset()
        self.model.eval()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()

        # Flattened (N, 1) accumulators over every non-ignored pixel.
        tot_conf = torch.empty(0, 1)
        tot_obj = torch.empty(0, 1)
        tot_label_for_image = torch.empty(0, 1)

        # Hard-coded evaluation switches (hoisted out of the batch loop —
        # they never change between iterations).
        doingCali = True  # apply the predetermined temperature below
        usingCRF = True   # refine the softmax output with DenseCRF
        # predetermined temperature (an earlier run used 1.6127)
        temp = 2.8 if doingCali else 1

        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)

            # temperature scaling of the logits
            output = output / temp
            # without CRF the (scaled) logits are fed to the metric directly;
            # argmax is unaffected by the softmax either way
            output_post = [] if usingCRF else output

            output = F.softmax(output, dim=1)
            output_numpy = output.cpu().numpy()

            for j, image_file_loc in enumerate(filename):
                prob_to_use = output_numpy[j]
                if usingCRF:
                    raw_image = self._get_raw_image(image_file_loc)
                    prob_post = self.postprocessor(raw_image, prob_to_use)
                    prob_to_use = prob_post
                    output_post.append(prob_post)

                prob_to_use = torch.tensor(prob_to_use)
                # per-pixel winning class and its confidence
                labels = torch.argmax(prob_to_use, dim=0)
                conf = torch.max(prob_to_use, dim=0)[0].cpu()
                obj = labels.cpu().float()
                label_for_image = target[j].view(-1, 1).cpu().float()
                sel = (label_for_image >= 0)  # drop ignore-index pixels

                tot_conf = torch.cat(
                    [tot_conf, conf.view(-1, 1)[sel].view(-1, 1)], dim=0)
                tot_obj = torch.cat(
                    [tot_obj, obj.view(-1, 1)[sel].view(-1, 1)], dim=0)
                tot_label_for_image = torch.cat([
                    tot_label_for_image,
                    label_for_image.view(-1, 1)[sel].view(-1, 1)
                ], dim=0)

            if usingCRF:
                # stack the per-image CRF outputs back into a batch tensor
                output_post = torch.tensor(np.array(output_post))
                output_post = output_post.to(self.device)

            self.metric.update(output_post, target)
            pixAcc, mIoU = self.metric.get()
            logging.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

        print(tot_conf.shape, tot_obj.shape, tot_label_for_image.shape)
        import pickle
        ece_folder = "eceData"
        makedirs(ece_folder)

        # postfix="DLV2_UnCal"
        postfix = "Foggy_Calibrated_DLV3Plus"
        saveDir = os.path.join(ece_folder, postfix)
        makedirs(saveDir)

        # Context managers guarantee the files are closed even on error
        # (the original left them open if pickle.dump raised).
        for fname, tensor in (("conf.pickle", tot_conf),
                              ("obj.pickle", tot_obj),
                              ("gt.pickle", tot_label_for_image)):
            with open(os.path.join(saveDir, fname), "wb") as f:
                pickle.dump(tensor, f)

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        headers = ['class id', 'class name', 'iou']
        table = []
        for i, cls_name in enumerate(self.classes):
            # the 'class id' column is supplied by tabulate's showindex
            table.append([cls_name, category_iou[i]])
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))
예제 #2
0
class Evaluator(object):
    """Evaluate a segmentation model on the validation split.

    Images smaller than ``cfg.TEST.CROP_SIZE`` are zero-padded up to the
    crop size before the forward pass and the output is cropped back to
    the original size.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # Optionally override eps on the encoder's BatchNorm layers.
        if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        ``named_modules`` (an iterable of (name, module) pairs)."""
        for m in named_modules:
            if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(m[1], attr, value)

    def eval(self):
        """Run one pass over the validation set, logging the running
        pixAcc/mIoU after every batch."""
        self.metric.reset()
        self.model.eval()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                size = image.size()[2:]
                if size[0] < cfg.TEST.CROP_SIZE[0] and size[
                        1] < cfg.TEST.CROP_SIZE[1]:
                    pad_height = cfg.TEST.CROP_SIZE[0] - size[0]
                    pad_width = cfg.TEST.CROP_SIZE[1] - size[1]
                    # F.pad pads the *last* dim first: (W_left, W_right,
                    # H_top, H_bottom).  The original code swapped
                    # pad_height/pad_width, padding H by the width delta
                    # and W by the height delta.
                    image = F.pad(image, (0, pad_width, 0, pad_height))
                    output = model(image)[0]
                    # crop back to the original spatial size
                    output = output[..., :size[0], :size[1]]
                else:
                    output = model(image)[0]

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            logging.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            # TODO: optionally save color-palette predictions when
            # args.save_pred is set.
        synchronize()
예제 #3
0
class Trainer(object):
    """Training driver.

    Builds the train/val data loaders, model, loss, optimizer and LR
    scheduler, then runs an iteration-based training loop with periodic
    logging, per-epoch checkpointing and periodic validation.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode=cfg.DATASET.MODE,
                                               **data_kwargs)
        # iterations per epoch over the global batch (all GPUs)
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)
        # print params and flops (rank 0 only; best-effort)
        if get_rank() == 0:
            try:
                show_flops_params(self.model, args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            # NOTE: the original `ext == '.pkl' or '.pth'` was always truthy
            # because of operator precedence; the membership test below is
            # what was intended.
            assert ext in ('.pkl', '.pth'), \
                'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_sate = torch.load(args.resume)
            self.model.load_state_dict(resume_sate['state_dict'])
            self.start_epoch = resume_sate['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            if resume_sate['optimizer'] is not None and resume_sate[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_sate['optimizer'])
                self.lr_scheduler.load_state_dict(resume_sate['lr_scheduler'])

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)
        self.best_pred = 0.0

    def train(self):
        """Run the iteration-based training loop.

        Logs every ``args.log_iter`` iterations, checkpoints once per
        epoch and validates every ``args.val_epoch`` epochs (rank 0 only
        for disk writes).
        """
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        # when resuming, continue counting from the restored epoch
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # remaining-time estimate from the average per-iteration cost
            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".
                    format(
                        epoch, epochs, iteration % iters_per_epoch,
                        iters_per_epoch, self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model,
                                epoch,
                                self.optimizer,
                                self.lr_scheduler,
                                is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(epoch)
                self.model.train()  # validation() switches to eval mode

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logging.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self, epoch):
        """Evaluate on the validation set and checkpoint when mIoU improves."""
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'val' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    size = image.size()[2:]
                    pad_height = cfg.TEST.CROP_SIZE[0] - size[0]
                    pad_width = cfg.TEST.CROP_SIZE[1] - size[1]
                    # F.pad pads the *last* dim first: (W_left, W_right,
                    # H_top, H_bottom).  The original code swapped
                    # pad_height/pad_width.
                    image = F.pad(image, (0, pad_width, 0, pad_height))
                    output = model(image)[0]
                    output = output[..., :size[0], :size[1]]

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            logging.info(
                "[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))
        pixAcc, mIoU = self.metric.get()
        logging.info(
            "[EVAL END] Epoch: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                epoch, pixAcc * 100, mIoU * 100))
        synchronize()
        if self.best_pred < mIoU and self.save_to_disk:
            self.best_pred = mIoU
            logging.info(
                'Epoch {} is the best model, best pixAcc: {:.3f}, mIoU: {:.3f}, save the model..'
                .format(epoch, pixAcc * 100, mIoU * 100))
            save_checkpoint(model, epoch, is_best=True)
예제 #4
0
class Evaluator(object):
    """Run a full pass over the test split and report overall and
    per-class IoU for the segmentation model."""

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # Normalization pipeline applied to every input image.
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # Build the test dataset and its (optionally distributed) loader.
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='test',
                                               mode='testval',
                                               transform=input_transform)
        sampler = make_data_sampler(val_dataset, False, args.distributed)
        batch_sampler = make_batch_data_sampler(
            sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = val_dataset.classes

        # Instantiate the network on the target device.
        self.model = get_segmentation_model().to(self.device)

        # Optionally override eps on the encoder's BatchNorm layers.
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Assign ``value`` to ``attr`` on each BatchNorm2d/SyncBatchNorm
        found among the given (name, module) pairs."""
        for _, module in named_modules:
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(module, attr, value)

    def eval(self):
        """Evaluate the model on the test split, logging running metrics
        per batch and a per-class IoU table at the end."""
        self.metric.reset()
        self.model.eval()
        model = self.model.module if self.args.distributed else self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()
        sample_idx = 0
        for image, target, filename in self.val_loader:
            sample_idx += 1
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            logging.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    sample_idx, pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        # The 'class id' column is supplied by tabulate's showindex.
        headers = ['class id', 'class name', 'iou']
        table = [[cls_name, category_iou[idx]]
                 for idx, cls_name in enumerate(self.classes)]
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))
예제 #5
0
class Evaluator(object):
    """Evaluate a segmentation model on the test split, tracking separate
    metrics for 'easy' and 'hard' images (classified by filename) in
    addition to the combined metric."""

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # test dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='test',
                                               mode='val',
                                               transform=input_transform,
                                               base_size=cfg.TRAIN.BASE_SIZE)

        # validation dataloader
        # val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
        #                                        split='validation',
        #                                        mode='val',
        #                                        transform=input_transform,
        #                                        base_size=cfg.TRAIN.BASE_SIZE)

        val_sampler = make_data_sampler(val_dataset,
                                        shuffle=False,
                                        distributed=args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)

        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        logging.info('**** number of images: {}. ****'.format(
            len(self.val_loader)))

        self.classes = val_dataset.classes
        # create network
        self.model = get_segmentation_model().to(self.device)

        # Optionally override eps on the encoder's BatchNorm layers.
        if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        self.model.to(self.device)
        num_gpu = args.num_gpus

        # metric of easy and hard images
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed, num_gpu)
        self.metric_easy = SegmentationMetric(val_dataset.num_class,
                                              args.distributed, num_gpu)
        self.metric_hard = SegmentationMetric(val_dataset.num_class,
                                              args.distributed, num_gpu)

        # number of easy and hard images
        self.count_easy = 0
        self.count_hard = 0

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        ``named_modules`` (an iterable of (name, module) pairs)."""
        for m in named_modules:
            if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(m[1], attr, value)

    def eval(self):
        """Run evaluation over all images; per-image routing into the
        easy/hard metric is decided by the filename substring."""
        # Reset every accumulator so eval() is repeatable — the original
        # reset only the combined metric, leaving the easy/hard metrics
        # and counters stale on a second call.
        self.metric.reset()
        self.metric_easy.reset()
        self.metric_hard.reset()
        self.count_easy = 0
        self.count_hard = 0
        self.model.eval()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()
        widgets = [
            'Inference: ',
            Percentage(), ' ',
            Bar('#'), ' ',
            Timer(), ' ',
            ETA(), ' ',
            FileTransferSpeed()
        ]
        pbar = ProgressBar(widgets=widgets,
                           maxval=10 * len(self.val_loader)).start()

        for i, (image, target, boundary,
                filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)
            boundary = boundary.to(self.device)

            # batch size is 1 here: route by the single filename
            filename = filename[0]
            with torch.no_grad():
                output, output_boundary = model.evaluate(image)

            if 'hard' in filename:
                self.metric_hard.update(output, target)
                self.count_hard += 1
            elif 'easy' in filename:
                self.metric_easy.update(output, target)
                self.count_easy += 1
            else:
                # filename matches neither bucket: report and skip entirely
                print(filename)
                continue

            self.metric.update(output, target)
            pbar.update(10 * i + 1)

        pbar.finish()
        synchronize()
        pixAcc, mIoU, category_iou, mae, mBer, category_Ber = self.metric.get(
            return_category_iou=True)
        pixAcc_e, mIoU_e, category_iou_e, mae_e, mBer_e, category_Ber_e = self.metric_easy.get(
            return_category_iou=True)
        pixAcc_h, mIoU_h, category_iou_h, mae_h, mBer_h, category_Ber_h = self.metric_hard.get(
            return_category_iou=True)

        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info(
            'End validation pixAcc: {:.2f}, mIoU: {:.2f}, mae: {:.3f}, mBer: {:.2f}'
            .format(pixAcc * 100, mIoU * 100, mae, mBer))
        logging.info(
            'End validation easy pixAcc: {:.2f}, mIoU: {:.2f}, mae: {:.3f}, mBer: {:.2f}'
            .format(pixAcc_e * 100, mIoU_e * 100, mae_e, mBer_e))
        logging.info(
            'End validation hard pixAcc: {:.2f}, mIoU: {:.2f}, mae: {:.3f}, mBer: {:.2f}'
            .format(pixAcc_h * 100, mIoU_h * 100, mae_h, mBer_h))

        # The 'class id' column is supplied by tabulate's showindex.
        headers = [
            'class id', 'class name', 'iou', 'iou_easy', 'iou_hard', 'ber',
            'ber_easy', 'ber_hard'
        ]
        table = []
        for i, cls_name in enumerate(self.classes):
            table.append([
                cls_name, category_iou[i], category_iou_e[i],
                category_iou_h[i], category_Ber[i], category_Ber_e[i],
                category_Ber_h[i]
            ])
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))
        logging.info('easy images: {}, hard images: {}'.format(
            self.count_easy, self.count_hard))
예제 #6
0
class Evaluator(object):
    """Single-pass validation of a segmentation model: per-sample and final
    pixAcc/mIoU plus a per-class IoU table logged via `tabulate`."""

    def __init__(self, args):
        # DenseCRF post-processor configured from cfg.CRF.
        # NOTE(review): not used by eval() below -- presumably kept for
        # parity with CRF-enabled evaluators; confirm before removing.
        self.postprocessor = DenseCRF(
            iter_max=cfg.CRF.ITER_MAX,
            pos_xy_std=cfg.CRF.POS_XY_STD,
            pos_w=cfg.CRF.POS_W,
            bi_xy_std=cfg.CRF.BI_XY_STD,
            bi_rgb_std=cfg.CRF.BI_RGB_STD,
            bi_w=cfg.CRF.BI_W,
        )
        self.args = args
        self.device = torch.device(args.device)

        # image transform: to tensor + dataset mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader (validation split, 'testval' mode)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)

        # sampler: no shuffling; distributed-aware when requested
        val_sampler = make_data_sampler(val_dataset,
                                        shuffle=False,
                                        distributed=args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = val_dataset.classes
        # create network
        self.model = get_segmentation_model().to(self.device)

        # optionally override batch-norm eps inside the encoder
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.to(self.device)

        # pixAcc/mIoU accumulator (distributed-aware)
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set `attr` to `value` on every BatchNorm2d/SyncBatchNorm in `named_modules`."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    def eval(self):
        """Run validation over `self.val_loader`, logging running and final metrics."""
        self.metric.reset()
        self.model.eval()
        # unwrap DDP so the custom evaluate() method is reachable
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()
        for i, (image, target, filename) in enumerate(self.val_loader):

            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)

            # NOTE(review): hook point for dumping raw logits per image
            # (assumes batch size 1), e.g.:
            # np.save('npy_files_voc/' + os.path.basename(filename[0]).strip('.jpg'), output[0].cpu().numpy())

            # upsample logits to input resolution, then per-pixel argmax
            output = F.interpolate(output, (image.shape[2], image.shape[3]),
                                   mode='bilinear',
                                   align_corners=True)
            output = torch.argmax(output, 1)

            self.metric.update(output, target)

            pixAcc, mIoU = self.metric.get()
            logging.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        # per-class IoU table
        headers = ['class id', 'class name', 'iou']
        table = []
        for i, cls_name in enumerate(self.classes):
            table.append([cls_name, category_iou[i]])
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))
예제 #7
0
class Trainer(object):
    """Fine-tunes an mmseg segmentation model with a calibration-aware
    (CCE/ECE-regularized) loss, logging metrics and figures to TensorBoard.

    The backbone is frozen; only the decode head(s) receive gradients.
    Best checkpoints are tracked both by validation mIoU and by the
    overall CCE score.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # TensorBoard writers: one for the clean val set, one reserved for
        # its noisy ("foggy") counterpart.
        self.prefix = "ADE_cce_alpha={}".format(cfg.TRAIN.ALPHA)
        self.writer = SummaryWriter(log_dir=f"iccv_tensorboard/{self.prefix}")
        self.writer_noisy = SummaryWriter(
            log_dir=f"iccv_tensorboard/{self.prefix}-foggy")

        # image transform: to tensor + dataset mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader (train and val use different crop sizes)
        train_data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }
        val_data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TEST.CROP_SIZE
        }
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **train_data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode="val",
                                               **val_data_kwargs)

        self.classes = val_dataset.classes
        # iteration budget: epochs * iters/epoch, where the global batch
        # size is num_gpus * per-GPU batch size
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        # calibration-error evaluators used during validation
        self.ece_evaluator = ECELoss(n_classes=len(self.classes))
        self.cce_evaluator = CCELoss(n_classes=len(self.classes))

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # create network from the mmseg config plus pretrained weights
        mmseg_config_file = cfg.MODEL.MMSEG_CONFIG
        mmseg_pretrained = cfg.TRAIN.PRETRAINED_MODEL_PATH
        self.model = init_segmentor(mmseg_config_file, mmseg_pretrained)
        self.model.to(self.device)

        # freeze the backbone; only the decode head(s) are trained
        for param in self.model.backbone.parameters():
            param.requires_grad = False

        # print params and flops (rank 0 only; best effort)
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion; it returns the loss dict plus the calibration
        # and NLL components separately for logging
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX,
            n_classes=len(train_dataset.classes),
            alpha=cfg.TRAIN.ALPHA).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer_mmseg(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            _, ext = os.path.splitext(args.resume)
            # BUGFIX: `ext == '.pkl' or '.pth'` was always truthy; test
            # membership so the extension is actually validated.
            assert ext in ('.pkl', '.pth'), \
                'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_state = torch.load(args.resume)
            self.model.load_state_dict(resume_state['state_dict'])
            self.start_epoch = resume_state['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            if resume_state['optimizer'] is not None and resume_state[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_state['optimizer'])
                self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

        # evaluation metric and best-checkpoint trackers
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)
        self.best_pred_miou = 0.0
        self.best_pred_cces = 1e15

    def train(self):
        """Main training loop, driven by the finite batch sampler (max_iters)."""
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        # resume mid-schedule when restarting from a checkpoint
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            # raw decoder logits; the criterion expects a sequence of outputs
            outputs = self.model.encode_decode(images, None)
            outputs = [outputs]

            loss_dict, loss_cal, loss_nll = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".
                    format(
                        epoch, epochs, iteration % iters_per_epoch,
                        iters_per_epoch, self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

                self.writer.add_scalar("Loss", losses_reduced.item(),
                                       iteration)
                self.writer.add_scalar("CCE oe ECE part of Loss",
                                       loss_cal.item(), iteration)
                self.writer.add_scalar("NLL Part", loss_nll.item(), iteration)
                self.writer.add_scalar("LR",
                                       self.optimizer.param_groups[0]['lr'],
                                       iteration)

            # periodic (per-epoch) checkpoint, rank 0 only
            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model,
                                epoch,
                                self.optimizer,
                                self.lr_scheduler,
                                is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(epoch, self.val_loader, self.writer)
                # a noisy-loader pass can be added here with
                # self.val_loader_noisy / self.writer_noisy
                self.model.train()

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logging.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self, epoch, val_loader, writer):
        """Evaluate on `val_loader`, log metrics/calibration tables to
        `writer`, and checkpoint on new best mIoU or best CCE score."""
        self.metric.reset()
        self.ece_evaluator.reset()
        self.cce_evaluator.reset()

        model = self.model
        torch.cuda.empty_cache()
        model.eval()
        # BUGFIX: iterate the loader passed in, not self.val_loader, so the
        # noisy loader (and any other caller) is actually evaluated.
        for i, (image, target, filename) in enumerate(val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.encode_decode(image, None)

            self.metric.update(output, target)

            # log the first sample's image and per-class probability maps
            if i == 0:
                import cv2
                image_read = cv2.imread(filename[0])
                writer.add_image("Image[0] Read",
                                 image_read,
                                 epoch,
                                 dataformats="HWC")

                save_imgs = torch.softmax(output, dim=1)[0]
                for class_no, class_distri in enumerate(save_imgs):
                    plt.clf()
                    # pin two pixels to 0/1 so the colormap range is stable
                    class_distri[0][0] = 0
                    class_distri[0][1] = 1

                    im = plt.imshow(class_distri.detach().cpu().numpy(),
                                    cmap="Greens")
                    plt.colorbar(im)
                    plt.savefig("temp_files/temp.jpg")
                    plt.clf()
                    img_dif = cv2.imread("temp_files/temp.jpg")

                    writer.add_image(f"Class_{self.classes[class_no]}",
                                     img_dif,
                                     epoch,
                                     dataformats="HWC")

            with torch.no_grad():
                self.ece_evaluator.forward(output, target)
                self.cce_evaluator.forward(output, target)

            pixAcc, mIoU = self.metric.get()
            logging.info(
                "[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))
        pixAcc, mIoU = self.metric.get()
        logging.info(
            "[EVAL END] Epoch: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                epoch, pixAcc * 100, mIoU * 100))

        writer.add_scalar("[EVAL END] pixAcc", pixAcc * 100, epoch)
        writer.add_scalar("[EVAL END] mIoU", mIoU * 100, epoch)

        # calibration tables / difference maps as images
        ece_count_table_image, _ = self.ece_evaluator.get_count_table_img(
            self.classes)
        ece_table_image, ece_dif_map = self.ece_evaluator.get_perc_table_img(
            self.classes)

        cce_count_table_image, _ = self.cce_evaluator.get_count_table_img(
            self.classes)
        cce_table_image, cce_dif_map = self.cce_evaluator.get_perc_table_img(
            self.classes)
        ece_dif_mean, ece_dif_std = self.ece_evaluator.get_diff_mean_std()
        cce_dif_mean, cce_dif_std = self.cce_evaluator.get_diff_mean_std()
        writer.add_image("ece_table",
                         ece_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("ece Count table",
                         ece_count_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("ece DifMap", ece_dif_map, epoch, dataformats="HWC")

        writer.add_scalar("ece_mean", ece_dif_mean, epoch)
        writer.add_scalar("ece_std", ece_dif_std, epoch)
        writer.add_scalar("ece Score",
                          self.ece_evaluator.get_overall_ECELoss(), epoch)
        writer.add_scalar("ece dif Score", self.ece_evaluator.get_diff_score(),
                          epoch)

        writer.add_image("cce_table",
                         cce_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("cce Count table",
                         cce_count_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("cce DifMap", cce_dif_map, epoch, dataformats="HWC")

        cces = self.cce_evaluator.get_overall_CCELoss()

        writer.add_scalar("cce_mean", cce_dif_mean, epoch)
        writer.add_scalar("cce_std", cce_dif_std, epoch)
        writer.add_scalar("cce Score", cces, epoch)
        writer.add_scalar("cce dif Score", self.cce_evaluator.get_diff_score(),
                          epoch)

        synchronize()
        # checkpoint on new best mIoU and/or new best (lowest) CCE score
        if self.best_pred_miou < mIoU and self.save_to_disk:
            self.best_pred_miou = mIoU
            logging.info(
                'Epoch {} is the best model for mIoU, best pixAcc: {:.3f}, mIoU: {:.3f}, save the model..'
                .format(epoch, pixAcc * 100, mIoU * 100))
            save_checkpoint(model, epoch, is_best=True, mode="iou")
        if self.best_pred_cces > cces and self.save_to_disk:
            self.best_pred_cces = cces
            logging.info(
                'Epoch {} is the best model for cceScore, best pixAcc: {:.3f}, mIoU: {:.3f}, save the model..'
                .format(epoch, pixAcc * 100, mIoU * 100))
            save_checkpoint(model, epoch, is_best=True, mode="cces")
예제 #8
0
    def giveComparisionImages_colormaps(self, pre_output, post_output,
                                        raw_image, gt_label, classes, outname):
        """Save a figure comparing pre- and post-calibration predictions.

        pre_output-> [1,21,h,w] cuda tensor (uncalibrated scores)
        post_output-> [1,21,h,w] cuda tensor (calibrated scores)
        raw_image->[1,3,h,w] cuda tensor
        gt_label->[1,h,w] cuda tensor
        classes: class-index -> name lookup; outname: output figure path.

        Layout: column 1 holds input / changed-pixel map / the two color
        maps; remaining columns hold per-class score maps and the
        decrease/increase difference maps.
        """
        # Per-image metrics before and after calibration.
        metric = SegmentationMetric(nclass=21, distributed=False)
        metric.update(pre_output, gt_label)
        pre_pixAcc, pre_mIoU = metric.get()

        metric = SegmentationMetric(nclass=21, distributed=False)
        metric.update(post_output, gt_label)
        post_pixAcc, post_mIoU = metric.get()

        # Classes actually predicted by each model.
        uncal_labels = np.unique(
            torch.argmax(pre_output.squeeze(0), dim=0).cpu().numpy())
        cal_labels = np.unique(
            torch.argmax(post_output.squeeze(0), dim=0).cpu().numpy())

        pre_label_map = torch.argmax(pre_output.squeeze(0),
                                     dim=0).cpu().numpy()
        post_label_map = torch.argmax(post_output.squeeze(0),
                                      dim=0).cpu().numpy()

        # Drop batch dims and move everything to numpy (image to HWC uint8).
        pre_output = pre_output.squeeze(0).cpu().numpy()
        post_output = post_output.squeeze(0).cpu().numpy()
        raw_image = raw_image.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(
            np.uint8)
        gt_label = gt_label.squeeze(0).cpu().numpy()

        # One column per predicted class (max over both models) plus two
        # leading columns.  (Equivalent to the old int(np.ceil(max+1))+1.)
        cols = max(len(uncal_labels), len(cal_labels)) + 2
        rows = 4

        plt.figure(figsize=(20, 20))
        # Row 1, col 1: raw input (BGR -> RGB for display).
        ax = plt.subplot(rows, cols, 1)
        ax.set_title("Input image")
        ax.imshow(raw_image[:, :, ::-1])
        ax.axis("off")

        # Row 2, col 1: binary map of pixels whose label changed.
        ax = plt.subplot(rows, cols, cols + 1)
        ax.set_title("Difference MAP ")
        ax.imshow(((pre_label_map != post_label_map).astype(np.uint8)))
        ax.axis("off")

        # Row 3, col 1: uncalibrated color map with its metrics.
        ax = plt.subplot(rows, cols, 2 * cols + 1)
        ax.set_title("ColorMap (uncal+crf) pixA={:.4f} mIoU={:.4f}".format(
            pre_pixAcc, pre_mIoU))
        mask = get_color_pallete(pre_label_map, cfg.DATASET.NAME)
        ax.imshow(np.array(mask))
        ax.axis("off")

        # Row 4, col 1: calibrated color map with its metrics.
        ax = plt.subplot(rows, cols, 3 * cols + 1)
        ax.set_title(
            "ColorMap (cal T = {} + CRF) pixA={:.4f} mIoU={:.4f}".format(
                self.temp, post_pixAcc, post_mIoU))
        mask = get_color_pallete(post_label_map, cfg.DATASET.NAME)
        ax.imshow(np.array(mask))
        ax.axis("off")

        # Row 1: per-class uncalibrated score maps.
        for i, label in enumerate(uncal_labels):
            ax = plt.subplot(rows, cols, i + 3)
            ax.set_title("Uncalibrated-" + classes[label])
            ax.imshow(pre_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 2: per-class calibrated score maps.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, cols + i + 3)
            ax.set_title("Calibrated-" + classes[label])
            ax.imshow(post_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 3: where calibration DECREASED the score (pre - post > 0),
        # normalized by the maximum decrease.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 2 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            max_dif = np.max(diff)
            dif_map = np.where(diff > 0, diff, 0)
            ax.set_title("decrease: " + classes[label] +
                         " max={:0.3f}".format(max_dif))
            ax.imshow(
                dif_map / max_dif,
                cmap="nipy_spectral",
            )
            ax.axis("off")

        # Row 4: where calibration INCREASED the score (pre - post < 0),
        # normalized by the maximum increase (min_dif is negative, and so
        # is dif_map, so the ratio is in [0, 1]).
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 3 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            min_dif = np.min(diff)
            dif_map = np.where(diff < 0, diff, 0)
            ax.set_title("increase: " + classes[label] +
                         " max={:0.3f}".format(-min_dif))
            ax.imshow(
                dif_map / min_dif,
                cmap="nipy_spectral",
            )
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(outname)
예제 #9
0
class Evaluator(object):
    """Validation of a segmentation model with DenseCRF refinement; builds
    loaders for both a clean and a noisy ('foggy') validation set."""

    def __init__(self, args):

        self.args = args
        self.device = torch.device(args.device)

        # image transform: to tensor + dataset mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader (clean validation split)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        # pixAcc/mIoU accumulator shared by eval() runs (reset per call)
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

        # DEFINE data for noisy ('foggy') validation variant
        val_dataset_noisy = get_segmentation_dataset(cfg.DATASET.NOISY_NAME,
                                                     split='val',
                                                     mode='testval',
                                                     transform=input_transform)
        val_sampler_noisy = make_data_sampler(val_dataset_noisy, False,
                                              args.distributed)
        val_batch_sampler_noisy = make_batch_data_sampler(
            val_sampler_noisy,
            images_per_batch=cfg.TEST.BATCH_SIZE,
            drop_last=False)
        self.val_loader_noisy = data.DataLoader(
            dataset=val_dataset_noisy,
            batch_sampler=val_batch_sampler_noisy,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # optionally override batch-norm eps inside the encoder
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        self.model.to(self.device)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set `attr` to `value` on every BatchNorm2d/SyncBatchNorm in `named_modules`."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    @torch.no_grad()
    def eval(self, val_loader, crf):
        """Evaluate `self.model` on `val_loader`, refining each softmax map
        with `crf(raw_image, probmap)`; returns (pixAcc*100, mIoU*100)."""
        self.metric.reset()
        self.model.eval()
        # NOTE(review): no DDP unwrap here (unlike other evaluators); confirm
        # model.evaluate is reachable when args.distributed is set.
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(val_loader)))
        import time
        time_start = time.time()

        for (image, target, filename) in tqdm(val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            output = model.evaluate(image)

            # CRF refinement: needs the original image from disk (assumes
            # batch size 1, hence filename[0])
            filename = filename[0]
            raw_image = cv2.imread(filename, cv2.IMREAD_COLOR).astype(np.uint8)

            probmap = torch.softmax(output, dim=1)[0].cpu().numpy()
            output = crf(raw_image, probmap)

            # put numpy back on gpu since target is on gpu
            output = torch.from_numpy(output).cuda().unsqueeze(dim=0)

            self.metric.update(output, target)

        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        return pixAcc * 100, mIoU * 100
예제 #10
0
class Trainer(object):
    """Training driver for a segmentation model.

    Trains on the clean dataset (``cfg.DATASET.NAME``) with a frozen
    encoder and periodically validates on both the clean val split and a
    noisy/foggy variant (``cfg.DATASET.NOISY_NAME``), logging losses,
    learning rate, segmentation metrics and ECE/CCE calibration tables to
    two TensorBoard writers.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # Experiment prefix -> TensorBoard log directories (clean / foggy).
        # self.prefix = "randomtesting"
        # self.prefix = "cce_alpha={}_sigma=0.5".format(cfg.TRAIN.ALPHA)
        self.prefix = "temp fcn scaling testing full={}".format(cfg.TRAIN.ALPHA)
        self.writer = SummaryWriter(log_dir=f"CCE_train_both_eval/{self.prefix}")
        self.writer_noisy = SummaryWriter(log_dir=f"CCE_train_both_eval/{self.prefix}-foggy")

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader
        data_kwargs = {'transform': input_transform, 'base_size': cfg.TRAIN.BASE_SIZE,
                       'crop_size': cfg.TRAIN.CROP_SIZE}

        # OUR TRAIN AND VAL ARE ALWAYS NORMAL CITYSCAPES
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='train', mode='train', **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='val', mode=cfg.DATASET.MODE, **data_kwargs)
        self.classes = val_dataset.classes
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus * cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        # Calibration-error evaluators reused by every validation pass.
        self.ece_evaluator = ECELoss(n_classes=len(self.classes))
        self.cce_evaluator = CCELoss(n_classes=len(self.classes))

        train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler, cfg.TRAIN.BATCH_SIZE, self.max_iters, drop_last=True)

        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, cfg.TEST.BATCH_SIZE, drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)

        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # DEFINE data for noisy
        # BUGFIX: the noisy loader previously reused `val_batch_sampler`,
        # which draws indices from the *clean* val dataset; a sampler must
        # be built over the dataset it actually indexes.
        val_dataset_noisy = get_segmentation_dataset(cfg.DATASET.NOISY_NAME, split='val', mode=cfg.DATASET.MODE, **data_kwargs)
        val_sampler_noisy = make_data_sampler(val_dataset_noisy, False, args.distributed)
        val_batch_sampler_noisy = make_batch_data_sampler(val_sampler_noisy, cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader_noisy = data.DataLoader(dataset=val_dataset_noisy,
                                                batch_sampler=val_batch_sampler_noisy,
                                                num_workers=cfg.DATASET.WORKERS,
                                                pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # Freeze the encoder: only decoder/head parameters receive gradients.
        for params in self.model.encoder.parameters():
            params.requires_grad = False

        # print params and flops
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info('Batch norm type is {}, convert_sync_batchnorm is not effective'.format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion
        self.criterion = get_segmentation_loss(cfg.MODEL.MODEL_NAME, use_ohem=cfg.SOLVER.OHEM,
                                               aux=cfg.SOLVER.AUX, aux_weight=cfg.SOLVER.AUX_WEIGHT,
                                               ignore_index=cfg.DATASET.IGNORE_INDEX, n_classes=len(train_dataset.classes), alpha=cfg.TRAIN.ALPHA).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer, max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            # BUGFIX: `ext == '.pkl' or '.pth'` was always truthy, so the
            # assert never fired for unsupported extensions.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(args.resume))
            resume_sate = torch.load(args.resume)
            self.model.load_state_dict(resume_sate['state_dict'])
            self.start_epoch = resume_sate['epoch']
            logging.info('resume train from epoch: {}'.format(self.start_epoch))
            if resume_sate['optimizer'] is not None and resume_sate['lr_scheduler'] is not None:
                logging.info('resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_sate['optimizer'])
                self.lr_scheduler.load_state_dict(resume_sate['lr_scheduler'])

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank],
                                                             output_device=args.local_rank,
                                                             find_unused_parameters=True)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class, args.distributed)
        self.best_pred = 0.0

    def train(self):
        """Main optimisation loop.

        Iterates the (infinite, max_iters-bounded) train loader, logging
        every ``args.log_iter`` iterations, checkpointing once per epoch,
        and validating on clean + noisy val sets every ``args.val_epoch``
        epochs.
        """
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info('Start training, Total Epochs: {:d} = Total Iterations {:d}'.format(epochs, max_iters))

        self.model.train()
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0

        total_loss = 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            # criterion returns (per-loss dict, calibration term, NLL term)
            loss_dict, loss_cal, loss_nll = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            total_loss += losses_reduced.item()

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            eta_seconds = ((time.time() - start_time) / iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                        epoch, epochs, iteration % iters_per_epoch, iters_per_epoch,
                        self.optimizer.param_groups[0]['lr'], losses_reduced.item(),
                        str(datetime.timedelta(seconds=int(time.time() - start_time))),
                        eta_string))

                self.writer.add_scalar("Loss", losses_reduced.item(), iteration)
                self.writer.add_scalar("CCE oe ECE part of Loss", loss_cal.item(), iteration)
                self.writer.add_scalar("NLL Part", loss_nll.item(), iteration)
                self.writer.add_scalar("LR", self.optimizer.param_groups[0]['lr'], iteration)

            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model, epoch, self.optimizer, self.lr_scheduler, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                # BUGFIX: validation() has signature (val_loader, writer,
                # temp=1); the old calls passed `epoch` first, shifting
                # every argument (epoch became the loader, the loader the
                # writer, the writer the temperature).
                self.validation(self.val_loader, self.writer)
                self.validation(self.val_loader_noisy, self.writer_noisy)
                self.model.train()

        total_training_time = time.time() - start_time
        total_training_str = str(datetime.timedelta(seconds=total_training_time))
        logging.info(
            "Total training time: {} ({:.4f}s / it)".format(
                total_training_str, total_training_time / max_iters))

    def validation(self, val_loader, writer, temp=1):
        """Evaluate the current model on ``val_loader`` at temperature ``temp``.

        Logits are divided by ``temp`` before scoring, so ``temp=1`` is the
        uncalibrated model.  Per-batch pixAcc/mIoU are logged; for the first
        batch the input image and per-class probability maps are written to
        TensorBoard.  ECE and CCE tables/scores are written at the end.
        TensorBoard steps use ``temp*10`` so fractional temperatures map to
        distinct integer steps.

        Args:
            val_loader: dataloader yielding ``(image, target, filename)``.
            writer: ``SummaryWriter`` to receive scalars/images.
            temp: softmax temperature used to rescale logits.
        """
        self.metric.reset()
        self.ece_evaluator.reset()
        self.cce_evaluator.reset()
        # Kept for the (commented-out) per-image CCE bookkeeping below.
        eceEvaluator_perimage = perimageCCE(n_classes=len(self.classes))

        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()

        for i, (image, target, filename) in enumerate(val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'val' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    # Pad up to the test crop size, then crop the logits back.
                    size = image.size()[2:]
                    pad_height = cfg.TEST.CROP_SIZE[0] - size[0]
                    pad_width = cfg.TEST.CROP_SIZE[1] - size[1]
                    image = F.pad(image, (0, pad_height, 0, pad_width))
                    output = model(image)[0]
                    output = output[..., :size[0], :size[1]]

            # output is [1, 19, 1024, 2048] logits
            # target is [1, 1024, 2048]

            # Temperature scaling of the logits.
            output /= temp
            self.metric.update(output, target)

            if (i == 0):
                # Dump the first image and its per-class probability maps.
                import cv2
                image_read = cv2.imread(filename[0])
                writer.add_image("Image[0] Read", image_read, temp * 10, dataformats="HWC")

                save_imgs = torch.softmax(output, dim=1)[0]
                for class_no, class_distri in enumerate(save_imgs):
                    plt.clf()
                    # Pin two corner pixels to 0/1 so the colormap scale is
                    # comparable across classes and temperatures.
                    class_distri[0][0] = 0
                    class_distri[0][1] = 1

                    im = plt.imshow(class_distri.detach().cpu().numpy(), cmap="Greens")
                    plt.colorbar(im)
                    plt.savefig("temp_files/temp.jpg")
                    plt.clf()
                    import cv2
                    img_dif = cv2.imread("temp_files/temp.jpg")

                    writer.add_image(f"Class_{self.classes[class_no]}", img_dif, temp * 10, dataformats="HWC")

            with torch.no_grad():
                self.ece_evaluator.forward(output, target)
                self.cce_evaluator.forward(output, target)
                # for output, target in zip(output,target.detach()):
                #     #older ece requires softmax and size output=[class,w,h] target=[w,h]
                #     eceEvaluator_perimage.update(output.softmax(dim=0), target)

            pixAcc, mIoU = self.metric.get()
            logging.info("[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(i + 1, pixAcc * 100, mIoU * 100))

        pixAcc, mIoU = self.metric.get()
        logging.info("[EVAL END] temp: {:f}, pixAcc: {:.3f}, mIoU: {:.3f}".format(temp, pixAcc * 100, mIoU * 100))
        writer.add_scalar("Temp [EVAL END] pixAcc", pixAcc * 100, temp * 10)
        writer.add_scalar("Temp [EVAL END] mIoU", mIoU * 100, temp * 10)

        # count_table_image, _ = eceEvaluator_perimage.get_count_table_img(self.classes)
        # cce_table_image, dif_map = eceEvaluator_perimage.get_perc_table_img(self.classes)

        ece_count_table_image, _ = self.ece_evaluator.get_count_table_img(self.classes)
        ece_table_image, ece_dif_map = self.ece_evaluator.get_perc_table_img(self.classes)

        cce_count_table_image, _ = self.cce_evaluator.get_count_table_img(self.classes)
        cce_table_image, cce_dif_map = self.cce_evaluator.get_perc_table_img(self.classes)

        # writer.add_image("Temp CCE_table", cce_table_image, temp, dataformats="HWC")
        # writer.add_image("Temp CCE Count table", count_table_image, temp, dataformats="HWC")
        # writer.add_image("Temp CCE DifMap", dif_map, temp, dataformats="HWC")
        # writer.add_scalar("Temp CCE Score", eceEvaluator_perimage.get_overall_CCELoss(), temp)

        writer.add_image("Temp ece_table", ece_table_image, temp * 10, dataformats="HWC")
        writer.add_image("Temp ece Count table", ece_count_table_image, temp * 10, dataformats="HWC")
        writer.add_image("Temp ece DifMap", ece_dif_map, temp * 10, dataformats="HWC")
        writer.add_scalar("Temp ece Score", self.ece_evaluator.get_overall_ECELoss(), temp * 10)

        writer.add_image("Temp cce_table", cce_table_image, temp * 10, dataformats="HWC")
        writer.add_image("Temp cce Count table", cce_count_table_image, temp * 10, dataformats="HWC")
        writer.add_image("Temp cce DifMap", cce_dif_map, temp * 10, dataformats="HWC")
        writer.add_scalar("Temp cce Score", self.cce_evaluator.get_overall_CCELoss(), temp * 10)
        synchronize()
        # if self.best_pred < mIoU and self.save_to_disk:
        #     self.best_pred = mIoU
        #     logging.info('Epoch {} is the best model, best pixAcc: {:.3f}, mIoU: {:.3f}, save the model..'.format(temp*10, pixAcc * 100, mIoU * 100))
        #     save_checkpoint(model, temp*10, is_best=True)

    def temp_scaling_validation(self, temp=1):
        """Run validation at temperature ``temp`` on both the clean and the
        noisy (foggy) validation loaders."""
        self.validation(self.val_loader, self.writer, temp=temp)
        self.validation(self.val_loader_noisy, self.writer_noisy, temp=temp)
예제 #11
0
class Evaluator(object):
    """Validation runner with temperature scaling, optional GaussCRF
    post-processing and ECE (expected calibration error) reporting.

    Results are appended to ``<postfix>.txt`` and ECE artefacts are written
    under ``<ece_folder>/<postfix>``.
    """

    def __init__(self, args):

        self.args = args
        self.device = torch.device(args.device)

        # ECE histogram configuration and output naming.
        self.n_bins = 15
        self.ece_folder = "eceData"
        # self.postfix="foggy_conv13_CityScapes_GPU"
        self.postfix = "foggy_zurich_conv13"
        # self.postfix="Foggy_1_conv13_PascalVOC_GPU"
        # Temperature by which logits are divided before scoring.
        self.temp = 1.5
        # self.useCRF=False
        self.useCRF = True

        self.ece_criterion = metrics.IterativeECELoss()
        self.ece_criterion.make_bins(n_bins=self.n_bins)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        print(args.distributed)
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

        self.model = get_segmentation_model().to(self.device)

        # Optionally override BN eps inside the encoder (config-driven).
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        self.model.to(self.device)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        the given ``(name, module)`` iterable."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    def eceOperations(self, bin_total, bin_total_correct, bin_conf_total):
        """Aggregate per-image ECE bins, append the loss to ``Results.txt``
        and save a reliability diagram under ``<ece_folder>/<postfix>``."""
        eceLoss = self.ece_criterion.get_interative_loss(
            bin_total, bin_total_correct, bin_conf_total)
        print('ECE with probabilties %f' % (eceLoss))

        saveDir = os.path.join(self.ece_folder, self.postfix)
        makedirs(saveDir)

        # BUGFIX: the results file handle was opened and never closed; use a
        # context manager so the write is flushed and the handle released.
        with open(os.path.join(saveDir, "Results.txt"), "a") as results_file:
            results_file.write(
                f"{self.postfix}_temp={self.temp}\t\t\t ECE Loss: {eceLoss}\n")

        plot_folder = os.path.join(saveDir, "plots")
        makedirs(plot_folder)

        rel_diagram = visualization.ReliabilityDiagramIterative()
        plt_test_2 = rel_diagram.plot(bin_total,
                                      bin_total_correct,
                                      bin_conf_total,
                                      title="Reliability Diagram")
        plt_test_2.savefig(os.path.join(plot_folder,
                                        f'rel_diagram_temp={self.temp}.png'),
                           bbox_inches='tight')

    def eval(self):
        """Evaluate the whole val loader: temperature-scale the logits,
        collect ECE bins (pre-CRF), optionally refine with GaussCRF, then
        report pixAcc/mIoU and append results to ``<postfix>.txt``."""
        self.metric.reset()
        self.model.eval()
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()
        # if(not self.useCRF):
        bin_total = []
        bin_total_correct = []
        bin_conf_total = []
        for (image, target, filename) in tqdm(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)
                output /= self.temp
                # ECE is computed on the temperature-scaled logits *before*
                # any CRF refinement.
                output_for_ece = output.clone()

                # if use CRF
                if (self.useCRF):
                    filename = filename[0]
                    raw_image = cv2.imread(filename, cv2.IMREAD_COLOR).astype(
                        np.float32).transpose(2, 0, 1)
                    raw_image = torch.from_numpy(raw_image).to(self.device)
                    raw_image = raw_image.unsqueeze(dim=0)
                    crf = GaussCRF(conf=get_default_conf(),
                                   shape=image.shape[2:],
                                   nclasses=len(self.classes),
                                   use_gpu=True)
                    crf = crf.to(self.device)
                    # print(image.shape,raw_image.shape)
                    assert image.shape == raw_image.shape
                    output = crf.forward(output, raw_image)

            # ECE Stuff: per-pixel confidence + predicted label vs target.
            conf = np.max(output_for_ece.softmax(dim=1).cpu().numpy(), axis=1)
            label = torch.argmax(output_for_ece, dim=1).cpu().numpy()
            bin_total_current, bin_total_correct_current, bin_conf_total_current = self.ece_criterion.get_collective_bins(
                conf, label,
                target.cpu().numpy())
            bin_total.append(bin_total_current)
            bin_total_correct.append(bin_total_correct_current)
            bin_conf_total.append(bin_conf_total_current)

            # Accuracy Stuff
            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()

        # ECE stuff
        # if(not self.useCRF):
        self.eceOperations(bin_total, bin_total_correct, bin_conf_total)

        # Accuracy stuff
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        # file=open("foggy_1_conv13_VOC.txt","a")
        with open(f"{self.postfix}.txt", "a") as file:
            file.write("Temp={} + crf\n".format(self.temp))
            file.write('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
                pixAcc * 100, mIoU * 100))

            file.write("\n\n")

        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))
예제 #12
0
class Evaluator(object):
    """Calibration-focused validation runner.

    Scores a segmentation model at a given softmax temperature, optionally
    refining both calibrated and uncalibrated logits with GaussCRF, and
    reports CCE calibration tables plus pixAcc/mIoU, appending them to
    ``foggy_zurich_cali_Acc.txt``.
    """

    def __init__(self, args, temp=1):
        # BUGFIX: args/device were assigned twice in a row; the duplicate
        # assignments were removed.
        self.args = args
        self.device = torch.device(args.device)

        # Calibration histogram / output configuration.
        self.n_bins = 10
        self.ece_folder = "experiments/classCali/eceData"
        # self.postfix = "Conv13_PascalVOC_GPU"
        # self.postfix = "Min_Foggy_1_conv13_PascalVOC_GPU"
        # self.postfix = "Foggy_1_conv13_PascalVOC_GPU"
        # self.postfix = "FoggyCityscapes_conv13_exp"
        self.postfix = "foggy_zurich_conv13"
        self.temp = temp
        print("Current temp being used : {}".format(self.temp))
        self.showProbMaps = False
        # self.useCRF=False
        self.useCRF = True

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        # Separate metrics for raw, uncalibrated-CRF and calibrated-CRF paths.
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)
        self.uncal_metric = SegmentationMetric(val_dataset.num_class,
                                               args.distributed)
        self.cal_metric = SegmentationMetric(val_dataset.num_class,
                                             args.distributed)

        self.model = get_segmentation_model().to(self.device)

        # Optionally override BN eps inside the encoder (config-driven).
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        self.model.to(self.device)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        the given ``(name, module)`` iterable."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    def eval(self):
        """Evaluate the full val loader at ``self.temp``.

        NOTE(review): only ``self.metric`` is reset here — ``cal_metric`` /
        ``uncal_metric`` accumulate across repeated eval() calls, and when
        ``useCRF`` is False ``cal_metric`` is never updated although the
        final report reads it. Confirm intended usage before relying on
        repeated calls or the no-CRF path.
        """
        self.metric.reset()
        self.model.eval()
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        eceEvaluator = CCELoss(n_classes=len(self.classes))
        time_start = time.time()
        for (image, target, filename) in tqdm(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)
                # output = torch.softmax(output, dim=1)
                # output =output*0 + 1/len(self.classes)

                output_uncal = output.clone()

                # output_cal = torch.log(torch.softmax(output, dim=1))
                # Temperature scaling of the logits.
                output_cal = output / self.temp

                # if use CRF
                if (self.useCRF):
                    filename = filename[0]
                    # print(filename)
                    raw_image = cv2.imread(filename, cv2.IMREAD_COLOR).astype(
                        np.float32).transpose(2, 0, 1)
                    raw_image = torch.from_numpy(raw_image).to(self.device)
                    raw_image = raw_image.unsqueeze(dim=0)
                    crf = GaussCRF(conf=get_default_conf(),
                                   shape=image.shape[2:],
                                   nclasses=len(self.classes),
                                   use_gpu=True)
                    crf = crf.to(self.device)
                    assert image.shape == raw_image.shape

                    # Refine calibrated and uncalibrated logits separately.
                    output_cal_crf = crf.forward(output_cal, raw_image)
                    output_uncal_crf = crf.forward(output_uncal, raw_image)
                    comparisionFolder = "experiments/classCali/comparisionImages"
                    saveFolder = os.path.join(
                        comparisionFolder, self.postfix + f"_temp={self.temp}")
                    makedirs(saveFolder)
                    saveName = os.path.join(saveFolder,
                                            os.path.basename(filename))

                    # image_pil = Image.open(filename).convert('RGB')
                    # after_crf_uncal_mask = get_color_pallete(torch.argmax(output_uncal_crf, dim=1).squeeze(0).cpu().numpy(), cfg.DATASET.NAME)
                    # after_crf_cal_mask = get_color_pallete(torch.argmax(output_cal_crf, dim=1).squeeze(0).cpu().numpy(), cfg.DATASET.NAME)
                    # difference_map_mask = get_color_pallete((torch.argmax(output_uncal_crf, dim=1).squeeze(0).cpu().numpy() != torch.argmax(output_cal_crf, dim=1).squeeze(0).cpu().numpy())*19, cfg.DATASET.NAME)
                    # gt_mask = get_color_pallete(target.squeeze(0).cpu().numpy(), cfg.DATASET.NAME)

                    # # Concatenating horizontally [out_post_crf,out_pre_crf, rgb]
                    # dst = Image.new('RGB', (5*gt_mask.width+12, gt_mask.height), color="white")
                    # # dst = Image.new('RGB', (4*gt_mask.width+9, gt_mask.height), color="white")
                    # dst.paste(after_crf_uncal_mask, (0, 0))
                    # dst.paste(after_crf_cal_mask, (gt_mask.width+3, 0))
                    # dst.paste(gt_mask, (2*gt_mask.width+6, 0))
                    # dst.paste(image_pil, (3*gt_mask.width+9, 0))
                    # dst.paste(difference_map_mask, (4*gt_mask.width+ 12, 0))
                    # dst.save(saveName)

                # Calibration error on the (pre-CRF) calibrated output.
                eceEvaluator.update(
                    output_cal.softmax(dim=1).squeeze(0), target.squeeze(0))
                # print(eceEvaluator.get_perc_table())

            if (self.useCRF):
                self.cal_metric.update(output_cal_crf, target)
                self.uncal_metric.update(output_uncal_crf, target)
            else:
                self.metric.update(output_cal, target)
                pixAcc, mIoU = self.metric.get()

        eceEvaluator.get_perc_table(self.classes)
        overallLoss = eceEvaluator.get_overall_CCELoss()
        eceEvaluator.get_classVise_CCELoss(self.classes)

        # f= open("cityscapes_cali_testing.txt", "a")

        # f.write(f"Temp: {self.temp} \t CCE Loss: {overallLoss}\n")
        # f.close()

        pixAccCalibrated, mIoUCal, category_iou = self.cal_metric.get(
            return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info(
            'End validation pixAccCalibrated: {:.3f}, mIoUCal: {:.3f}'.format(
                pixAccCalibrated * 100, mIoUCal * 100))
        # pixAccUncal, mIoUunCal, category_iou = self.uncal_metric.get(return_category_iou=True)
        # logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start))
        # logging.info('End validation pixAccUncal: {:.3f}, mIoUunCal: {:.3f}'.format(
        #         pixAccUncal * 100, mIoUunCal * 100))
        # pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        # logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start))
        # logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
        #         pixAcc * 100, mIoU * 100))

        # Append results; context manager guarantees the file is closed.
        with open("foggy_zurich_cali_Acc.txt", "a") as f:
            f.write("Temp: {} \t pixAcc: {:.3f}, mIoU: {:.3f} \n".format(
                self.temp, pixAccCalibrated * 100, mIoUCal * 100))
            f.write(f"Temp: {self.temp} \t CCE Loss: {overallLoss}\n")
            f.write("")
예제 #13
0
class Evaluator(object):
    """Runs validation of an MMSegmentation model over the configured
    dataset and reports pixel accuracy and mIoU."""

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # Per-image normalisation pipeline.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # Validation dataset and loader.
        dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='val', mode='testval', transform=transform)
        sampler = make_data_sampler(dataset, False, args.distributed)

        #####################
        # BATCH SIZE is always 1

        batch_sampler = make_batch_data_sampler(sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=dataset,
                                          batch_sampler=batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = dataset.classes

        ### Create network ###

        # Segmentron model
        # self.model = get_segmentation_model().to(self.device)

        # MMSeg model
        mmseg_config_file = "mmseg-configs/deeplabv3plus_r101-d8_512x512_80k_ade20k.py"
        mmseg_pretrained = "pretrained_weights/deeplabv3plus_r101-d8_512x512_80k_ade20k_20200615_014139-d5730af7.pth"
        self.model = init_segmentor(mmseg_config_file, mmseg_pretrained)
        self.model.to(self.device)

        self.metric = SegmentationMetric(dataset.num_class, args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        the given ``(name, module)`` iterable."""
        for _, module in named_modules:
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(module, attr, value)

    def eval(self):
        """Score every validation sample with ``mmseg_evaluate`` and log the
        running and final pixAcc/mIoU."""
        self.metric.reset()
        self.model.eval()
        model = self.model.module if self.args.distributed else self.model

        logging.info("Using Val/Test img scale : {}".format(cfg.TEST.IMG_SCALE))
        logging.info("Start validation, Total sample: {:d}".format(len(self.val_loader)))
        import time
        time_start = time.time()

        progress = tqdm(self.val_loader)
        for image, target, filename in progress:
            image, target = image.to(self.device), target.to(self.device)

            assert image.shape[0] == 1, "Only batch-size 1 allowed when evaluating on test/val images"

            with torch.no_grad():
                output = mmseg_evaluate(model, image, target)

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            progress.set_postfix_str("pixAcc: {:.3f}, mIoU: {:.3f}".format(pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc * 100, mIoU * 100))
예제 #14
0
class Trainer(object):
    """Training driver for a segmentation model.

    Builds train/val/test datasets and loaders, the network, loss, optimizer
    and LR scheduler from the global ``cfg``, optionally resumes from a
    checkpoint and wraps the model for distributed training, then exposes
    :meth:`train`, :meth:`validation` and :meth:`test` loops.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)
        self.use_fp16 = cfg.TRAIN.APEX  # mixed precision via NVIDIA apex (O1)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }

        # val/test share the training base size but use the TEST crop size
        data_kwargs_testval = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TEST.CROP_SIZE
        }

        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               **data_kwargs_testval)
        test_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                split='test',
                                                mode='testval',
                                                **data_kwargs_testval)

        self.classes = test_dataset.classes

        # epochs are converted to a fixed iteration budget so the train
        # batch sampler can stream `max_iters` batches in one pass
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)

        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        test_sampler = make_data_sampler(test_dataset, False, args.distributed)
        test_batch_sampler = make_batch_data_sampler(test_sampler,
                                                     cfg.TEST.BATCH_SIZE,
                                                     drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.test_loader = data.DataLoader(dataset=test_dataset,
                                           batch_sampler=test_batch_sampler,
                                           num_workers=cfg.DATASET.WORKERS,
                                           pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # print params and flops (rank 0 only, best-effort)
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))
        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')
        # create criterion
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)
        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)
        # apex
        if self.use_fp16:
            self.model, self.optimizer = apex.amp.initialize(self.model.cuda(),
                                                             self.optimizer,
                                                             opt_level="O1")
            logging.info('**** Initializing mixed precision done. ****')

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)
        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            # BUGFIX: original `ext == '.pkl' or '.pth'` was always truthy
            # ('.pth' is a non-empty string), so bad extensions never failed.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_state = torch.load(args.resume)
            self.model.load_state_dict(resume_state['state_dict'])
            self.start_epoch = resume_state['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            if resume_state['optimizer'] is not None and resume_state[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_state['optimizer'])
                self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)

    def train(self):
        """Run the iteration-based optimization loop.

        Logs progress every ``args.log_iter`` iterations, checkpoints once per
        epoch (rank 0 only), and runs validation + test every
        ``args.val_epoch`` epochs unless ``args.skip_val`` is set.
        """
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        # resume mid-run: the batch sampler streams max_iters batches total
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss for loss in loss_dict.values())
            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            if self.use_fp16:
                # apex scales the loss to avoid fp16 gradient underflow
                with apex.amp.scale_loss(losses,
                                         self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".
                    format(
                        epoch, epochs, iteration % iters_per_epoch,
                        iters_per_epoch, self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model,
                                epoch,
                                self.optimizer,
                                self.lr_scheduler,
                                is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(epoch)
                self.test()
                self.model.train()  # eval loops switch the model to eval mode

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logging.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self, epoch):
        """Evaluate on the validation loader and log per-sample and final
        pixAcc/mIoU for `epoch`."""
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'val' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    # fixed-crop evaluation: input must match TEST.CROP_SIZE
                    size = image.size()[2:]
                    assert cfg.TEST.CROP_SIZE[0] == size[0]
                    assert cfg.TEST.CROP_SIZE[1] == size[1]
                    output = model(image)[0]

            self.metric.update(output, target)
            pixAcc, mIoU, category_iou = self.metric.get(
                return_category_iou=True)
            logging.info(
                "[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))
        pixAcc, mIoU = self.metric.get()
        logging.info(
            "[EVAL END] Epoch: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                epoch, pixAcc * 100, mIoU * 100))
        synchronize()

    def test(self, vis=False):
        """Evaluate on the test loader; log per-sample metrics and a final
        per-category IoU table.

        When `vis` is True, predictions (or ground truth, if the hard-coded
        `save_gt` flag is flipped) are saved as color-palette images instead
        of updating the metric.
        """
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()
        for i, (image, target, filename) in enumerate(self.test_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'test' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    # fixed-crop evaluation: input must match TEST.CROP_SIZE
                    size = image.size()[2:]
                    assert cfg.TEST.CROP_SIZE[0] == size[0]
                    assert cfg.TEST.CROP_SIZE[1] == size[1]
                    output = model(image)[0]

            if vis:
                save_gt = False  # debug switch: dump GT instead of predictions
                if save_gt:
                    test_path = '/mnt/lustre/xieenze/xez_space/TransparentSeg/datasets/transparent/Trans10K_cls12/test/images'
                    save_path = 'workdirs/trans10kv2/gt_img'
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    gt_img = Image.open(os.path.join(test_path,
                                                     filename[0])).resize(
                                                         (512, 512))
                    gt_img.save(os.path.join(save_path, str(i) + '.png'))

                    gt_mask = target[0].data.cpu().numpy()
                    vis_gt = get_color_pallete(gt_mask, dataset='trans10kv2')
                    save_path = 'workdirs/trans10kv2/gt_mask'
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    vis_gt.save(os.path.join(save_path, str(i) + '.png'))
                else:
                    vis_pred = output[0].permute(
                        1, 2, 0).argmax(-1).data.cpu().numpy()
                    vis_pred = get_color_pallete(vis_pred,
                                                 dataset='trans10kv2')
                    save_path = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'vis')
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    vis_pred.save(os.path.join(save_path, str(i) + '.png'))
                print("[VIS TEST] Sample: {:d}".format(i + 1))
                continue  # visualization mode skips metric accumulation

            self.metric.update(output, target)
            pixAcc, mIoU, category_iou = self.metric.get(
                return_category_iou=True)
            logging.info(
                "[TEST] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info("[TEST END]  pixAcc: {:.3f}, mIoU: {:.3f}".format(
            pixAcc * 100, mIoU * 100))

        # per-class IoU table; tabulate's showindex supplies the 'class id' column
        headers = ['class id', 'class name', 'iou']
        table = []
        for i, cls_name in enumerate(self.classes):
            table.append([cls_name, category_iou[i]])
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))