Example #1
0
class Evaluator(object):
    """Validate a segmentation model with temperature scaling, optional
    dense-CRF post-processing, and ECE (Expected Calibration Error)
    reporting with reliability-diagram plots."""

    def __init__(self, args):
        """Set up calibration config, dataset/loader, metric and model.

        Args:
            args: parsed CLI namespace; must expose ``device``,
                ``distributed`` and ``local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Calibration configuration: bin count for ECE and run naming.
        self.n_bins = 15
        self.ece_folder = "eceData"
        # self.postfix="foggy_conv13_CityScapes_GPU"
        self.postfix = "foggy_zurich_conv13"
        # self.postfix="Foggy_1_conv13_PascalVOC_GPU"
        self.temp = 1.5  # temperature applied to the logits before softmax
        # self.useCRF=False
        self.useCRF = True

        self.ece_criterion = metrics.IterativeECELoss()
        self.ece_criterion.make_bins(n_bins=self.n_bins)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        print(args.distributed)  # NOTE(review): debug print kept as-is
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

        self.model = get_segmentation_model().to(self.device)

        # Optionally override the BatchNorm eps inside the encoder only.
        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        self.model.to(self.device)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        ``named_modules`` (an iterable of ``(name, module)`` pairs)."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    def eceOperations(self, bin_total, bin_total_correct, bin_conf_total):
        """Aggregate per-batch bin statistics into an ECE loss, append the
        result to ``Results.txt`` and save a reliability diagram plot."""
        eceLoss = self.ece_criterion.get_interative_loss(
            bin_total, bin_total_correct, bin_conf_total)
        print('ECE with probabilties %f' % (eceLoss))

        saveDir = os.path.join(self.ece_folder, self.postfix)
        makedirs(saveDir)

        # BUGFIX: the original opened this file and never closed it; the
        # context manager guarantees the handle (and buffered data) is
        # released even on error.
        with open(os.path.join(saveDir, "Results.txt"), "a") as file:
            file.write(
                f"{self.postfix}_temp={self.temp}\t\t\t ECE Loss: {eceLoss}\n")

        plot_folder = os.path.join(saveDir, "plots")
        makedirs(plot_folder)

        rel_diagram = visualization.ReliabilityDiagramIterative()
        plt_test_2 = rel_diagram.plot(bin_total,
                                      bin_total_correct,
                                      bin_conf_total,
                                      title="Reliability Diagram")
        plt_test_2.savefig(os.path.join(plot_folder,
                                        f'rel_diagram_temp={self.temp}.png'),
                           bbox_inches='tight')

    def eval(self):
        """Run validation: temperature-scale the logits, optionally refine
        with a Gaussian CRF, accumulate ECE bins and segmentation metrics,
        then append results to ``<postfix>.txt``."""
        self.metric.reset()
        self.model.eval()
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()
        # if(not self.useCRF):
        bin_total = []
        bin_total_correct = []
        bin_conf_total = []
        for (image, target, filename) in tqdm(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)
                output /= self.temp  # temperature scaling
                # ECE is measured on the pre-CRF (calibrated) logits.
                output_for_ece = output.clone()

                # if use CRF
                if (self.useCRF):
                    filename = filename[0]
                    raw_image = cv2.imread(filename, cv2.IMREAD_COLOR).astype(
                        np.float32).transpose(2, 0, 1)
                    raw_image = torch.from_numpy(raw_image).to(self.device)
                    raw_image = raw_image.unsqueeze(dim=0)
                    crf = GaussCRF(conf=get_default_conf(),
                                   shape=image.shape[2:],
                                   nclasses=len(self.classes),
                                   use_gpu=True)
                    crf = crf.to(self.device)
                    # The CRF expects the raw image at the network input size.
                    assert image.shape == raw_image.shape
                    output = crf.forward(output, raw_image)

            # ECE Stuff: per-pixel max softmax confidence and argmax label.
            conf = np.max(output_for_ece.softmax(dim=1).cpu().numpy(), axis=1)
            label = torch.argmax(output_for_ece, dim=1).cpu().numpy()
            bin_total_current, bin_total_correct_current, bin_conf_total_current = self.ece_criterion.get_collective_bins(
                conf, label,
                target.cpu().numpy())
            bin_total.append(bin_total_current)
            bin_total_correct.append(bin_total_correct_current)
            bin_conf_total.append(bin_conf_total_current)

            # Accuracy Stuff
            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()

        # ECE stuff
        self.eceOperations(bin_total, bin_total_correct, bin_conf_total)

        # Accuracy stuff
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        # file=open("foggy_1_conv13_VOC.txt","a")
        # Context manager replaces the manual open/close pair.
        with open(f"{self.postfix}.txt", "a") as file:
            file.write("Temp={} + crf\n".format(self.temp))
            file.write('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
                pixAcc * 100, mIoU * 100))
            file.write("\n\n")

        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))
Example #2
0
class Evaluator(object):
    """Plain validation runner: builds the validation pipeline, evaluates
    the model batch by batch, and logs overall and per-category IoU."""

    def __init__(self, args):
        """Construct the loader, model (optionally DDP-wrapped) and metric."""
        self.args = args
        self.device = torch.device(args.device)

        # Pre-processing applied to every validation image.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # Validation dataset and its (optionally distributed) loader.
        dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                           split='val',
                                           mode='testval',
                                           transform=transform)
        sampler = make_data_sampler(dataset, False, args.distributed)
        batch_sampler = make_batch_data_sampler(
            sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=dataset,
                                          batch_sampler=batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = dataset.classes

        # create network
        self.model = get_segmentation_model().to(self.device)

        # Apply a custom BatchNorm eps to the encoder when configured.
        if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.to(self.device)

        self.metric = SegmentationMetric(dataset.num_class, args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Assign ``attr`` = ``value`` on every batch-norm layer found in
        the given ``(name, module)`` iterable."""
        for _, module in named_modules:
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(module, attr, value)

    def eval(self):
        """Evaluate over the validation set, logging per-sample progress and
        a final per-category IoU table."""
        self.metric.reset()
        self.model.eval()
        # Unwrap the DDP container so .evaluate() is reachable.
        model = self.model.module if self.args.distributed else self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        time_start = time.time()
        for idx, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            logging.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    idx + 1, pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        headers = ['class id', 'class name', 'iou']
        table = [[cls_name, category_iou[i]]
                 for i, cls_name in enumerate(self.classes)]
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))
Example #3
0
class Trainer(object):
    """Training driver: builds the data pipelines, model, loss, optimizer
    and LR scheduler, then runs iteration-based training with periodic
    validation and checkpointing."""

    def __init__(self, args):
        """Assemble everything needed for training from ``cfg`` and ``args``.

        Args:
            args: parsed CLI namespace; must expose ``device``, ``distributed``,
                ``local_rank``, ``num_gpus`` and ``resume``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader
        data_kwargs = {'transform': input_transform, 'base_size': cfg.TRAIN.BASE_SIZE,
                       'crop_size': cfg.TRAIN.CROP_SIZE}
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='train', mode='train', **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='val', mode=cfg.DATASET.MODE, **data_kwargs)
        # Iteration-based schedule: epoch boundaries are derived, not looped.
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus * cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler, cfg.TRAIN.BATCH_SIZE, self.max_iters, drop_last=True)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, cfg.TEST.BATCH_SIZE, drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # print params and flops (rank 0 only; failure is non-fatal)
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info('Batch norm type is {}, convert_sync_batchnorm is not effective'.format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion
        self.criterion = get_segmentation_loss(cfg.MODEL.MODEL_NAME, use_ohem=cfg.SOLVER.OHEM,
                                               aux=cfg.SOLVER.AUX, aux_weight=cfg.SOLVER.AUX_WEIGHT,
                                               ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer, max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            # BUGFIX: the original `assert ext == '.pkl' or '.pth'` was always
            # true ('.pth' is a truthy string and `or` binds after `==`), so
            # the extension check never fired.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(args.resume))
            resume_state = torch.load(args.resume)
            # The checkpoint is a bare state_dict; epoch/optimizer state are
            # not restored, so training restarts from epoch 0.
            self.model.load_state_dict(resume_state)
            self.start_epoch = 0
            logging.info('resume train from epoch: {}'.format(self.start_epoch))

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank],
                                                             output_device=args.local_rank,
                                                             find_unused_parameters=True)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class, args.distributed)
        self.best_pred = 0.0

    def train(self):
        """Run the full iteration-based training loop with periodic logging,
        checkpointing and validation."""
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info('Start training, Total Epochs: {:d} = Total Iterations {:d}'.format(epochs, max_iters))

        self.model.train()
        # The train sampler already yields max_iters batches; resume skips
        # nothing because the checkpoint stores no iteration counter.
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # Linear ETA estimate from average per-iteration wall time.
            eta_seconds = ((time.time() - start_time) / iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                        epoch, epochs, iteration % iters_per_epoch, iters_per_epoch,
                        self.optimizer.param_groups[0]['lr'], losses_reduced.item(),
                        str(datetime.timedelta(seconds=int(time.time() - start_time))),
                        eta_string))

            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model, epoch, self.optimizer, self.lr_scheduler, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(epoch)
                self.model.train()  # validation() put the model in eval mode

        total_training_time = time.time() - start_time
        total_training_str = str(datetime.timedelta(seconds=total_training_time))
        logging.info(
            "Total training time: {} ({:.4f}s / it)".format(
                total_training_str, total_training_time / max_iters))

    def validation(self, epoch):
        """Evaluate on the validation set and checkpoint when mIoU improves.

        Args:
            epoch: 1-based epoch index used for logging and checkpoint naming.
        """
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'val' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    # Pad up to the test crop size, then crop the output back.
                    size = image.size()[2:]
                    pad_height = cfg.TEST.CROP_SIZE[0] - size[0]
                    pad_width = cfg.TEST.CROP_SIZE[1] - size[1]
                    image = F.pad(image, (0, pad_height, 0, pad_width))
                    output = model(image)[0]
                    output = output[..., :size[0], :size[1]]

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            logging.info("[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(i + 1, pixAcc * 100, mIoU * 100))
        pixAcc, mIoU = self.metric.get()
        logging.info("[EVAL END] Epoch: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(epoch, pixAcc * 100, mIoU * 100))
        synchronize()
        if self.best_pred < mIoU and self.save_to_disk:
            self.best_pred = mIoU
            logging.info('Epoch {} is the best model, best pixAcc: {:.3f}, mIoU: {:.3f}, save the model..'.format(epoch, pixAcc * 100, mIoU * 100))
            save_checkpoint(model, epoch, is_best=True)
Example #4
0
class Evaluator(object):
    """Toy calibration experiment: learns a class-mixing weight matrix and
    bias applied to synthetic logits, tracking CCE/cross-entropy losses and
    diagnostic images in TensorBoard."""

    def __init__(self, args):
        """Build the validation loader, TensorBoard writer and metric.

        Args:
            args: parsed CLI namespace; must expose ``device`` and
                ``distributed``.
        """

        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # Learning rate for the temperature parameters; also tags the run.
        self.lr = 2.5
        self.prefix = f"2_boxes_3_7={self.lr}"
        # self.prefix = f"overfit__count_toy_experiment_3class_7_2_1_conf_loss=total_xavier_weights_xavier_bias_lr={self.lr}"
        self.writer = SummaryWriter(log_dir=f"cce_toy_logs/{self.prefix}")
        # self.writer = SummaryWriter(log_dir= f"cce_cityscapes_logs/{self.prefix}")
        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        # val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        # NOTE: plain shuffled DataLoader here (no distributed sampler).
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          shuffle=True,
                                          batch_size=cfg.TEST.BATCH_SIZE,
                                          drop_last=True,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

        # self.model = get_segmentation_model().to(self.device)

        # if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
        #     cfg.MODEL.BN_EPS_FOR_ENCODER:
        #     logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
        #     self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)

        # if args.distributed:
        #     self.model = nn.parallel.DistributedDataParallel(self.model,
        #         device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)

        # self.model.to(self.device)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr`` to ``value`` on every BatchNorm2d/SyncBatchNorm in
        ``named_modules`` (an iterable of ``(name, module)`` pairs)."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    def eval(self):
        """Fit ``temp_weights``/``temp_bias`` by SGD on fixed synthetic
        logits, logging losses, calibration tables and per-class probability
        maps to TensorBoard every epoch."""
        self.metric.reset()
        print(f"Length of classes: {len(self.classes)}")
        # Learnable calibration parameters: a class-mixing matrix and a bias.
        # NOTE(review): the identity init is immediately overwritten by the
        # xavier init on the next line — presumably intentional; confirm.
        temp_weights = torch.eye(len(self.classes), device="cuda")
        torch.nn.init.xavier_uniform_(temp_weights, gain=1.0)
        print(temp_weights)
        temp_weights.requires_grad = True
        # temp_weights.requires_grad= True
        temp_bias = torch.zeros(len(self.classes), device="cuda")
        # torch.nn.init.xavier_uniform_(temp_bias, gain=1.0)
        temp_bias.requires_grad = True
        # temp_weights = torch.rand(len(self.classes), len(self.classes), device="cuda", requires_grad=True)
        # temp_bias = torch.rand(len(self.classes), device="cuda", requires_grad=True)

        logging.info(
            "Start training of temprature weights, Total sample: {:d}".format(
                len(self.val_loader)))

        cce_criterion = CCELoss(len(self.classes)).to(self.device)
        cross_criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        # Only the two calibration tensors are optimized; there is no model.
        optimizer = torch.optim.SGD([temp_weights, temp_bias], lr=self.lr)
        import time
        time_start = time.time()
        num_epochs = 300
        for epoch in range(num_epochs):
            eceEvaluator_perimage = perimageCCE(n_classes=len(self.classes))
            epoch_loss_cce_total = 0
            epoch_loss_cross_entropy_total = 0
            epoch_loss_total = 0
            for i, (images, targets, filenames) in enumerate(self.val_loader):
                # import pdb; pdb.set_trace()
                optimizer.zero_grad()

                images = images.to(self.device)
                targets = targets.to(self.device)

                # print(image.shape)
                with torch.no_grad():
                    # outputs = model.evaluate(images)

                    # outputs = torch.rand(1,3,300,400)
                    # Synthetic 2-class probability map: base (0.3, 0.7),
                    # then the left half favors class 0, right half class 1.
                    outputs = torch.ones(1, 2, 300, 400) * (torch.Tensor(
                        [0.3, 0.7]).reshape(1, -1, 1, 1))
                    # outputs = torch.ones(1,4,300,400)*(torch.Tensor([0.5,0.25,0.15, 0.1]).reshape(1,-1,1,1))
                    outputs = outputs.cuda()
                    outputs[0, 0, :, :200] = 0.7
                    outputs[0, 1, :, 200:] = 0.3

                    # outputs = torch.ones(1,3,300,400)*(torch.Tensor([0.7,0.2,0.1]).reshape(1,-1,1,1))
                    # # outputs = torch.ones(1,4,300,400)*(torch.Tensor([0.5,0.25,0.15, 0.1]).reshape(1,-1,1,1))
                    # outputs = outputs.cuda()
                    # outputs[0,0,100:200, 50:150] = 0.1
                    # outputs[0,0,100:150, 250:300] = 0.2
                    # outputs[0,1,100:200, 50:150] = 0.7
                    # outputs[0,1,100:150, 250:300] = 0.1
                    # outputs[0,2,100:200, 50:150] = 0.2
                    # outputs[0,2,100:150, 250:300] = 0.7

                    # Converting back to logits
                    outputs = torch.log(outputs)

                # Apply the learnable calibration in channel-last layout,
                # then restore NCHW. Gradients flow only through these ops.
                outputs = outputs.permute(0, 2, 3, 1).contiguous()
                outputs = torch.matmul(outputs, temp_weights)
                outputs = outputs + temp_bias

                outputs = outputs.permute(0, 3, 1, 2).contiguous()

                # Add image stuff
                save_imgs = torch.softmax(outputs, dim=1).squeeze(0)
                # analyse(outputs = save_imgs.unsqueeze(0))
                # accuracy(outputs = outputs)
                for class_no, class_distri in enumerate(save_imgs):
                    plt.clf()
                    # Pin two corner pixels to 0 and 1 so the colormap scale
                    # stays fixed across epochs (in-place edit of the map).
                    class_distri[0][0] = 0
                    class_distri[0][1] = 1

                    im = plt.imshow(class_distri.detach().cpu().numpy(),
                                    cmap="Greens")
                    plt.colorbar(im)
                    # Round-trip through a temp JPEG to get an HWC uint8
                    # array for TensorBoard.
                    plt.savefig("temp_files/temp.jpg")
                    plt.clf()
                    import cv2
                    img_dif = cv2.imread("temp_files/temp.jpg")

                    self.writer.add_image(f"Class_{class_no}",
                                          img_dif,
                                          epoch,
                                          dataformats="HWC")

                loss_cce = cce_criterion.forward(outputs, targets)
                loss_cross_entropy = cross_criterion.forward(outputs, targets)

                # alpha weights the cross-entropy term; 0 trains on CCE only.
                alpha = 0
                total_loss = loss_cce + alpha * loss_cross_entropy

                epoch_loss_cce_total += loss_cce.item()
                epoch_loss_cross_entropy_total += loss_cross_entropy.item()
                epoch_loss_total += total_loss.item()

                total_loss.backward()
                optimizer.step()

                with torch.no_grad():
                    for output, target in zip(outputs, targets.detach()):
                        # older ece requires softmax and size output=[class,w,h] target=[w,h]
                        eceEvaluator_perimage.update(output.softmax(dim=0),
                                                     target)
                # print(outputs.shape)
                # print(eceEvaluator_perimage.get_overall_CCELoss())
                print(
                    f"batch :{i+1}/{len(self.val_loader)}" +
                    "loss cce : {:.5f} | loss cls : {:.5f} | loss tot : {:.5f}"
                    .format(loss_cce, loss_cross_entropy, total_loss))

            print(temp_weights)
            print(temp_bias)
            # Average the accumulated batch losses over the epoch.
            epoch_loss_cce_total /= len(self.val_loader)
            epoch_loss_cross_entropy_total /= len(self.val_loader)
            epoch_loss_total /= len(self.val_loader)

            count_table_image, _ = eceEvaluator_perimage.get_count_table_img(
                self.classes)
            cce_table_image, dif_map = eceEvaluator_perimage.get_perc_table_img(
                self.classes)
            self.writer.add_image("CCE_table",
                                  cce_table_image,
                                  epoch,
                                  dataformats="HWC")
            self.writer.add_image("Count table",
                                  count_table_image,
                                  epoch,
                                  dataformats="HWC")
            self.writer.add_image("DifMap", dif_map, epoch, dataformats="HWC")

            self.writer.add_scalar(f"Cross EntropyLoss_LR",
                                   epoch_loss_cross_entropy_total, epoch)
            self.writer.add_scalar(f"CCELoss_LR", epoch_loss_cce_total, epoch)
            self.writer.add_scalar(f"Total Loss_LR", epoch_loss_total, epoch)
            self.writer.add_histogram("Weights", temp_weights, epoch)
            self.writer.add_histogram("Bias", temp_bias, epoch)
            # output = output/temp_weights
            # print(output.shape)
            # print(temp_weights, temp_bias)

            # Periodically snapshot the learned parameters to disk.
            if epoch > 0 and epoch % 10 == 0:
                print("saving weights.")
                np.save("weights/toy/wt_{}_{}.npy".format(epoch, self.prefix),
                        temp_weights.cpu().detach().numpy())
                np.save("weights/toy/b{}_{}.npy".format(epoch, self.prefix),
                        temp_bias.cpu().detach().numpy())

            # print("epoch {} : loss {:.5f}".format(epoch, epoch_loss))
            # import pdb; pdb.set_trace()

        self.writer.close()
Example #5
0
class Evaluator(object):
    """Temperature-scaled (calibrated) evaluation of a segmentation model.

    Runs the model over the validation split, divides the logits by a
    temperature before softmax, optionally refines both the calibrated and
    uncalibrated predictions with a Gaussian CRF, and reports pixel
    accuracy, mIoU and class-wise calibration (CCE) losses.
    """

    def __init__(self, args, temp=1):
        """
        Args:
            args: parsed CLI namespace; uses ``.device``, ``.distributed``
                and ``.local_rank``.
            temp: temperature divisor applied to logits (1 => uncalibrated).
        """
        self.args = args
        self.device = torch.device(args.device)

        self.n_bins = 10
        self.ece_folder = "experiments/classCali/eceData"
        # Tag appended to output folders/files for this experiment.
        self.postfix = "foggy_zurich_conv13"
        self.temp = temp
        print("Current temp being used : {}".format(self.temp))
        self.showProbMaps = False
        self.useCRF = True

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        # Separate accumulators for the plain / uncalibrated / calibrated runs.
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)
        self.uncal_metric = SegmentationMetric(val_dataset.num_class,
                                               args.distributed)
        self.cal_metric = SegmentationMetric(val_dataset.num_class,
                                             args.distributed)

        self.model = get_segmentation_model().to(self.device)

        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        # Model is already on `self.device`; DDP requires that before wrapping.
        # (The original moved the model to the device a second time after
        # wrapping — redundant, removed.)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set attribute `attr` to `value` on every (Sync)BatchNorm module."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    def eval(self):
        """Evaluate on the validation set and append results to a text file.

        With ``self.useCRF`` enabled, both the calibrated and uncalibrated
        outputs are refined by a per-image Gaussian CRF before scoring.
        """
        self.metric.reset()
        self.model.eval()
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time
        eceEvaluator = CCELoss(n_classes=len(self.classes))
        time_start = time.time()
        for (image, target, filename) in tqdm(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                output = model.evaluate(image)

                output_uncal = output.clone()
                # Temperature scaling: divide logits before softmax.
                output_cal = output / self.temp

                if (self.useCRF):
                    # CRF refinement needs the raw (unnormalized) image.
                    filename = filename[0]
                    raw_image = cv2.imread(filename, cv2.IMREAD_COLOR).astype(
                        np.float32).transpose(2, 0, 1)
                    raw_image = torch.from_numpy(raw_image).to(self.device)
                    raw_image = raw_image.unsqueeze(dim=0)
                    crf = GaussCRF(conf=get_default_conf(),
                                   shape=image.shape[2:],
                                   nclasses=len(self.classes),
                                   use_gpu=True)
                    crf = crf.to(self.device)
                    assert image.shape == raw_image.shape

                    output_cal_crf = crf.forward(output_cal, raw_image)
                    output_uncal_crf = crf.forward(output_uncal, raw_image)
                    # Folder for optional side-by-side comparison images.
                    comparisionFolder = "experiments/classCali/comparisionImages"
                    saveFolder = os.path.join(
                        comparisionFolder, self.postfix + f"_temp={self.temp}")
                    makedirs(saveFolder)
                    saveName = os.path.join(saveFolder,
                                            os.path.basename(filename))

                # CCE is accumulated on the calibrated (non-CRF) probabilities.
                eceEvaluator.update(
                    output_cal.softmax(dim=1).squeeze(0), target.squeeze(0))

            if (self.useCRF):
                self.cal_metric.update(output_cal_crf, target)
                self.uncal_metric.update(output_uncal_crf, target)
            else:
                self.metric.update(output_cal, target)
                pixAcc, mIoU = self.metric.get()

        eceEvaluator.get_perc_table(self.classes)
        overallLoss = eceEvaluator.get_overall_CCELoss()
        eceEvaluator.get_classVise_CCELoss(self.classes)

        pixAccCalibrated, mIoUCal, category_iou = self.cal_metric.get(
            return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info(
            'End validation pixAccCalibrated: {:.3f}, mIoUCal: {:.3f}'.format(
                pixAccCalibrated * 100, mIoUCal * 100))

        # Append results; the context manager guarantees the handle is closed
        # even if a write fails (original left the file open on error).
        with open("foggy_zurich_cali_Acc.txt", "a") as f:
            f.write("Temp: {} \t pixAcc: {:.3f}, mIoU: {:.3f} \n".format(
                self.temp, pixAccCalibrated * 100, mIoUCal * 100))
            f.write(f"Temp: {self.temp} \t CCE Loss: {overallLoss}\n")
Example #6
0
class Evaluator(object):
    """Validation-set evaluator backed by an MMSegmentation model.

    Loads a pretrained DeepLabV3+ checkpoint via ``init_segmentor`` and
    measures pixel accuracy / mIoU over the configured validation split.
    Evaluation requires batch size 1.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # Normalisation matching the training statistics.
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # Validation dataset + loader (no shuffling; batch size is
        # effectively always 1 for test/val evaluation).
        val_dataset = get_segmentation_dataset(
            cfg.DATASET.NAME, split='val', mode='testval',
            transform=input_transform)
        sampler = make_data_sampler(val_dataset, False, args.distributed)
        batch_sampler = make_batch_data_sampler(
            sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = val_dataset.classes

        # Build the MMSeg network from config file + pretrained weights.
        mmseg_config_file = "mmseg-configs/deeplabv3plus_r101-d8_512x512_80k_ade20k.py"
        mmseg_pretrained = "pretrained_weights/deeplabv3plus_r101-d8_512x512_80k_ade20k_20200615_014139-d5730af7.pth"
        self.model = init_segmentor(mmseg_config_file, mmseg_pretrained)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class, args.distributed)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set `attr` to `value` on every (Sync)BatchNorm in `named_modules`."""
        for _name, module in named_modules:
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(module, attr, value)

    def eval(self):
        """Run single-image evaluation over the val loader, logging metrics."""
        self.metric.reset()
        self.model.eval()
        # Unwrap the DDP container when running distributed.
        model = self.model.module if self.args.distributed else self.model

        logging.info("Using Val/Test img scale : {}".format(cfg.TEST.IMG_SCALE))
        logging.info("Start validation, Total sample: {:d}".format(len(self.val_loader)))
        import time
        time_start = time.time()
        progress = tqdm(self.val_loader)
        for image, target, filename in progress:
            image, target = image.to(self.device), target.to(self.device)

            assert image.shape[0] == 1, "Only batch-size 1 allowed when evaluating on test/val images"

            with torch.no_grad():
                output = mmseg_evaluate(model, image, target)

            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            progress.set_postfix_str("pixAcc: {:.3f}, mIoU: {:.3f}".format(pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc * 100, mIoU * 100))
Example #7
0
class Trainer(object):
    """Train a segmentation network with a calibration-aware loss.

    Trains on the clean dataset with a frozen encoder and periodically
    validates on both the clean and the noisy ("foggy") validation sets,
    logging losses, ECE/CCE calibration tables and per-class probability
    maps to TensorBoard.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # Run name used for the TensorBoard log directories.
        self.prefix = "FCN_cce_alpha={}_sigma=0.5".format(cfg.TRAIN.ALPHA)
        self.writer = SummaryWriter(
            log_dir=f"CCE_train_both_eval/{self.prefix}")
        self.writer_noisy = SummaryWriter(
            log_dir=f"CCE_train_both_eval/{self.prefix}-foggy")

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }

        # Train and val always use the clean (non-noisy) dataset.
        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode=cfg.DATASET.MODE,
                                               **data_kwargs)
        self.classes = val_dataset.classes
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        self.ece_evaluator = ECELoss(n_classes=len(self.classes))
        self.cce_evaluator = CCELoss(n_classes=len(self.classes))

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)

        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)

        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        # Noisy (foggy) validation data.
        # NOTE(review): this reuses the clean `val_batch_sampler`, which
        # assumes both val datasets have identical length/ordering — confirm.
        val_dataset_noisy = get_segmentation_dataset(cfg.DATASET.NOISY_NAME,
                                                     split='val',
                                                     mode=cfg.DATASET.MODE,
                                                     **data_kwargs)
        self.val_loader_noisy = data.DataLoader(
            dataset=val_dataset_noisy,
            batch_sampler=val_batch_sampler,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # Freeze the encoder: only decoder/head parameters are trained.
        for params in self.model.encoder.parameters():
            params.requires_grad = False

        # print params and flops
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))

        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')

        # create criterion (returns total loss dict + calibration/NLL parts)
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX,
            n_classes=len(train_dataset.classes),
            alpha=cfg.TRAIN.ALPHA).to(self.device)

        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)

        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            # BUGFIX: original `ext == '.pkl' or '.pth'` was always truthy
            # (non-empty string), so the check never rejected anything.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_state = torch.load(args.resume)
            self.model.load_state_dict(resume_state['state_dict'])
            self.start_epoch = resume_state['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            if resume_state['optimizer'] is not None and resume_state[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_state['optimizer'])
                self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)
        self.best_pred = 0.0

    def train(self):
        """Run the iteration-based training loop.

        Logs every ``args.log_iter`` iterations, checkpoints once per epoch,
        and validates (clean + noisy) every ``args.val_epoch`` epochs.
        """
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        # Resume mid-run: continue the iteration counter from the saved epoch.
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0

        total_loss = 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict, loss_cal, loss_nll = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            total_loss += losses_reduced.item()

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # Estimated time remaining, based on average time per iteration.
            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".
                    format(
                        epoch, epochs, iteration % iters_per_epoch,
                        iters_per_epoch, self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

                self.writer.add_scalar("Loss", losses_reduced.item(),
                                       iteration)
                self.writer.add_scalar("CCE oe ECE part of Loss",
                                       loss_cal.item(), iteration)
                self.writer.add_scalar("NLL Part", loss_nll.item(), iteration)
                self.writer.add_scalar("LR",
                                       self.optimizer.param_groups[0]['lr'],
                                       iteration)

            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model,
                                epoch,
                                self.optimizer,
                                self.lr_scheduler,
                                is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(epoch, self.val_loader, self.writer)
                self.validation(epoch, self.val_loader_noisy,
                                self.writer_noisy)
                # validation() switches to eval mode; switch back.
                self.model.train()

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logging.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self, epoch, val_loader, writer):
        """Evaluate on `val_loader` and log metrics, calibration tables and
        first-sample probability maps to the given TensorBoard `writer`."""
        self.metric.reset()
        self.ece_evaluator.reset()
        self.cce_evaluator.reset()
        eceEvaluator_perimage = perimageCCE(n_classes=len(self.classes))

        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()

        for i, (image, target, filename) in enumerate(val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'val' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    # Pad up to the test crop size, then crop logits back.
                    size = image.size()[2:]
                    pad_height = cfg.TEST.CROP_SIZE[0] - size[0]
                    pad_width = cfg.TEST.CROP_SIZE[1] - size[1]
                    image = F.pad(image, (0, pad_height, 0, pad_width))
                    output = model(image)[0]
                    output = output[..., :size[0], :size[1]]

            # output: [1, n_classes, H, W] logits; target: [1, H, W]
            self.metric.update(output, target)

            if (i == 0):
                # Log the first raw image and its per-class probability maps.
                import cv2
                image_read = cv2.imread(filename[0])
                writer.add_image("Image[0] Read",
                                 image_read,
                                 epoch,
                                 dataformats="HWC")

                save_imgs = torch.softmax(output, dim=1)[0]
                for class_no, class_distri in enumerate(save_imgs):
                    plt.clf()
                    # Pin two corner pixels to 0/1 so the colormap range is
                    # fixed across classes and epochs.
                    class_distri[0][0] = 0
                    class_distri[0][1] = 1

                    im = plt.imshow(class_distri.detach().cpu().numpy(),
                                    cmap="Greens")
                    plt.colorbar(im)
                    plt.savefig("temp_files/temp.jpg")
                    plt.clf()
                    img_dif = cv2.imread("temp_files/temp.jpg")

                    writer.add_image(f"Class_{self.classes[class_no]}",
                                     img_dif,
                                     epoch,
                                     dataformats="HWC")

            with torch.no_grad():
                self.ece_evaluator.forward(output, target)
                self.cce_evaluator.forward(output, target)

            pixAcc, mIoU = self.metric.get()
            logging.info(
                "[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

        pixAcc, mIoU = self.metric.get()
        logging.info(
            "[EVAL END] Epoch: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                epoch, pixAcc * 100, mIoU * 100))
        writer.add_scalar("[EVAL END] pixAcc", pixAcc * 100, epoch)
        writer.add_scalar("[EVAL END] mIoU", mIoU * 100, epoch)

        # Render calibration summaries as images/scalars for TensorBoard.
        ece_count_table_image, _ = self.ece_evaluator.get_count_table_img(
            self.classes)
        ece_table_image, ece_dif_map = self.ece_evaluator.get_perc_table_img(
            self.classes)

        cce_count_table_image, _ = self.cce_evaluator.get_count_table_img(
            self.classes)
        cce_table_image, cce_dif_map = self.cce_evaluator.get_perc_table_img(
            self.classes)

        writer.add_image("ece_table",
                         ece_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("ece Count table",
                         ece_count_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("ece DifMap", ece_dif_map, epoch, dataformats="HWC")
        writer.add_scalar("ece Score",
                          self.ece_evaluator.get_overall_ECELoss(), epoch)
        writer.add_scalar("ece dif Score", self.ece_evaluator.get_diff_score(),
                          epoch)

        writer.add_image("cce_table",
                         cce_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("cce Count table",
                         cce_count_table_image,
                         epoch,
                         dataformats="HWC")
        writer.add_image("cce DifMap", cce_dif_map, epoch, dataformats="HWC")
        writer.add_scalar("cce Score",
                          self.cce_evaluator.get_overall_CCELoss(), epoch)
        writer.add_scalar("cce dif Score", self.cce_evaluator.get_diff_score(),
                          epoch)
        synchronize()
        if self.best_pred < mIoU and self.save_to_disk:
            self.best_pred = mIoU
            logging.info(
                'Epoch {} is the best model, best pixAcc: {:.3f}, mIoU: {:.3f}, save the model..'
                .format(epoch, pixAcc * 100, mIoU * 100))
            save_checkpoint(model, epoch, is_best=True)
class Evaluator(object):
    """Evaluate a segmentation model with a caller-supplied CRF refinement.

    Builds loaders for both the clean and the noisy ("foggy") validation
    sets; ``eval`` runs the model over a given loader, refines each
    prediction with the provided CRF, and returns pixel accuracy / mIoU.
    """

    def __init__(self, args):
        """
        Args:
            args: parsed CLI namespace; uses ``.device``, ``.distributed``
                and ``.local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

        # Noisy (foggy) validation data with its own sampler/loader.
        val_dataset_noisy = get_segmentation_dataset(cfg.DATASET.NOISY_NAME,
                                                     split='val',
                                                     mode='testval',
                                                     transform=input_transform)
        val_sampler_noisy = make_data_sampler(val_dataset_noisy, False,
                                              args.distributed)
        val_batch_sampler_noisy = make_batch_data_sampler(
            val_sampler_noisy,
            images_per_batch=cfg.TEST.BATCH_SIZE,
            drop_last=False)
        self.val_loader_noisy = data.DataLoader(
            dataset=val_dataset_noisy,
            batch_sampler=val_batch_sampler_noisy,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True)

        self.model = get_segmentation_model().to(self.device)

        if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
            cfg.MODEL.BN_EPS_FOR_ENCODER:
            logging.info('set bn custom eps for bn in encoder: {}'.format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps',
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        # Model is already on `self.device`; DDP requires that before
        # wrapping. (The original repeated `self.model.to(self.device)`
        # after wrapping — redundant, removed.)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set attribute `attr` to `value` on every (Sync)BatchNorm module."""
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(
                    m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value)

    @torch.no_grad()
    def eval(self, val_loader, crf):
        """Evaluate on `val_loader`, refining each prediction with `crf`.

        Args:
            val_loader: loader yielding (image, target, filename) triples.
            crf: CRF module whose ``forward(logits, raw_image)`` refines the
                prediction using the raw (unnormalized) image.

        Returns:
            (pixAcc, mIoU) as percentages.
        """
        self.metric.reset()
        self.model.eval()
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(val_loader)))
        import time
        time_start = time.time()

        for (image, target, filename) in tqdm(val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            output = model.evaluate(image)

            # CRF refinement operates on the raw (unnormalized) image.
            filename = filename[0]
            raw_image = cv2.imread(filename, cv2.IMREAD_COLOR).astype(
                np.float32).transpose(2, 0, 1)
            raw_image = torch.from_numpy(raw_image).to(self.device)
            raw_image = raw_image.unsqueeze(dim=0)
            assert image.shape == raw_image.shape
            output = crf.forward(output, raw_image)

            self.metric.update(output, target)
            # NOTE(review): per-iteration get() result is unused; kept in
            # case SegmentationMetric.get() synchronizes across ranks.
            pixAcc, mIoU = self.metric.get()

        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() -
                                                           time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        return pixAcc * 100, mIoU * 100
Example #9
0
class Evaluator(object):
    def __init__(self, args):
        """Build the evaluator: ECE bookkeeping, val dataloader and model.

        Args:
            args: parsed CLI namespace; reads ``device``, ``distributed``
                and ``local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # ECE (expected calibration error) configuration.
        self.n_bins = 15
        self.ece_folder = "experiments/classCali/eceData"
        # self.postfix = "Conv13_PascalVOC_GPU"
        # self.postfix = "Min_Foggy_1_conv13_PascalVOC_GPU"
        self.postfix = "MINFoggy_1_conv13_PascalVOC_GPU"
        # Temperature used when rescaling logits for calibration.
        self.temp = 1.7
        # self.useCRF=False
        self.useCRF = True

        self.ece_criterion = metrics.IterativeECELoss()
        self.ece_criterion.make_bins(n_bins=self.n_bins)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split="val",
                                               mode="testval",
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
        self.val_loader = data.DataLoader(
            dataset=val_dataset,
            batch_sampler=val_batch_sampler,
            num_workers=cfg.DATASET.WORKERS,
            pin_memory=True,
        )

        self.dataset = val_dataset
        self.classes = val_dataset.classes
        self.metric = SegmentationMetric(val_dataset.num_class,
                                         args.distributed)

        self.model = get_segmentation_model().to(self.device)

        # Optionally override BN epsilon in the encoder only.
        if (hasattr(self.model, "encoder")
                and hasattr(self.model.encoder, "named_modules")
                and cfg.MODEL.BN_EPS_FOR_ENCODER):
            logging.info("set bn custom eps for bn in encoder: {}".format(
                cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), "eps",
                                     cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True,
            )

        self.model.to(self.device)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Assign ``value`` to ``attr`` on all BatchNorm2d/SyncBatchNorm
        modules found in ``named_modules`` ((name, module) pairs)."""
        for pair in named_modules:
            module = pair[1]
            if isinstance(module, nn.BatchNorm2d) or \
                    isinstance(module, nn.SyncBatchNorm):
                setattr(module, attr, value)

    def giveComparisionImages_colormaps(self, pre_output, post_output,
                                        raw_image, gt_label, classes, outname):
        """Save a 4-row comparison figure for one image.

        Column 1 shows (top to bottom): the input image, a binary map of
        pixels whose argmax label changed, and the uncalibrated/calibrated
        color maps annotated with pixAcc/mIoU.  The remaining columns show
        per-class probability maps and the decrease/increase difference
        maps between the two outputs.

        Args:
            pre_output: [1, 21, h, w] tensor (probabilities before calibration).
            post_output: [1, 21, h, w] tensor (probabilities after calibration).
            raw_image: [1, 3, h, w] tensor, BGR order as read by cv2.
            gt_label: [1, h, w] tensor of ground-truth class ids.
            classes: sequence of class names indexed by class id.
            outname: path the figure is saved to.
        """
        # Per-image metrics before and after calibration (21 VOC classes).
        metric = SegmentationMetric(nclass=21, distributed=False)
        metric.update(pre_output, gt_label)
        pre_pixAcc, pre_mIoU = metric.get()

        metric = SegmentationMetric(nclass=21, distributed=False)
        metric.update(post_output, gt_label)
        post_pixAcc, post_mIoU = metric.get()

        # Classes actually predicted somewhere in each output.
        uncal_labels = np.unique(
            torch.argmax(pre_output.squeeze(0), dim=0).cpu().numpy())
        cal_labels = np.unique(
            torch.argmax(post_output.squeeze(0), dim=0).cpu().numpy())

        pre_label_map = torch.argmax(pre_output.squeeze(0),
                                     dim=0).cpu().numpy()
        post_label_map = torch.argmax(post_output.squeeze(0),
                                      dim=0).cpu().numpy()

        # Drop the batch dimension and convert everything to numpy.
        pre_output = pre_output.squeeze(0).cpu().numpy()
        post_output = post_output.squeeze(0).cpu().numpy()
        raw_image = raw_image.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(
            np.uint8)
        gt_label = gt_label.squeeze(0).cpu().numpy()

        # One column per predicted class plus two leading summary columns.
        cols = int(np.ceil((max(len(uncal_labels), len(cal_labels)) + 1))) + 1
        rows = 4

        plt.figure(figsize=(20, 20))

        # Input image (cv2 reads BGR; reverse channels for display).
        ax = plt.subplot(rows, cols, 1)
        ax.set_title("Input image")
        ax.imshow(raw_image[:, :, ::-1])
        ax.axis("off")

        # Pixels whose predicted label changed after calibration.
        ax = plt.subplot(rows, cols, cols + 1)
        ax.set_title("Difference MAP ")
        ax.imshow(((pre_label_map != post_label_map).astype(np.uint8)))
        ax.axis("off")

        # Uncalibrated prediction as a color map, with its metrics.
        ax = plt.subplot(rows, cols, 2 * cols + 1)
        ax.set_title("ColorMap (uncal+crf) pixA={:.4f} mIoU={:.4f}".format(
            pre_pixAcc, pre_mIoU))
        mask = get_color_pallete(pre_label_map, cfg.DATASET.NAME)
        ax.imshow(np.array(mask))
        ax.axis("off")

        # Calibrated prediction as a color map, with its metrics.
        ax = plt.subplot(rows, cols, 3 * cols + 1)
        ax.set_title(
            "ColorMap (cal T = {} + CRF) pixA={:.4f} mIoU={:.4f}".format(
                self.temp, post_pixAcc, post_mIoU))
        mask = get_color_pallete(post_label_map, cfg.DATASET.NAME)
        ax.imshow(np.array(mask))
        ax.axis("off")

        # Row 1: per-class probability maps before calibration.
        for i, label in enumerate(uncal_labels):
            ax = plt.subplot(rows, cols, i + 3)
            ax.set_title("Uncalibrated-" + classes[label])
            ax.imshow(pre_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 2: per-class probability maps after calibration.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, cols + i + 3)
            ax.set_title("Calibrated-" + classes[label])
            ax.imshow(post_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 3: where calibration DECREASED the class probability,
        # normalized by the largest decrease.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 2 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            max_dif = np.max(diff)
            dif_map = np.where(diff > 0, diff, 0)
            ax.set_title("decrease: " + classes[label] +
                         " max={:0.3f}".format(max_dif))
            # NOTE(review): max_dif may be 0 (no decrease anywhere), in
            # which case this division produces NaNs, as in the original.
            ax.imshow(dif_map / max_dif, cmap="nipy_spectral")
            ax.axis("off")

        # Row 4: where calibration INCREASED the class probability,
        # normalized by the largest increase (min_dif is <= 0).
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 3 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            min_dif = np.min(diff)
            dif_map = np.where(diff < 0, diff, 0)
            ax.set_title("increase: " + classes[label] +
                         " max={:0.3f}".format(-min_dif))
            ax.imshow(dif_map / min_dif, cmap="nipy_spectral")
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(outname)

    def giveComparisionImages_after_crf(self, pre_output, post_output,
                                        raw_image, gt_label, classes, outname):
        """Save a 4-row comparison figure for CRF-refined outputs.

        Same layout as :meth:`giveComparisionImages_colormaps`, but both
        inputs are assumed to already be CRF-refined: row 1 shows the
        uncalibrated+CRF per-class maps, row 2 the calibrated+CRF maps,
        and rows 3/4 the normalized decrease/increase difference maps.

        Args:
            pre_output: [1, 21, h, w] tensor (uncalibrated + CRF).
            post_output: [1, 21, h, w] tensor (calibrated + CRF).
            raw_image: [1, 3, h, w] tensor, BGR order as read by cv2.
            gt_label: [1, h, w] tensor of ground-truth class ids (unused
                here beyond shape normalization).
            classes: sequence of class names indexed by class id.
            outname: path the figure is saved to.
        """
        # Classes actually predicted somewhere in each output.
        uncal_labels = np.unique(
            torch.argmax(pre_output.squeeze(0), dim=0).cpu().numpy())
        cal_labels = np.unique(
            torch.argmax(post_output.squeeze(0), dim=0).cpu().numpy())

        # Drop the batch dimension and convert everything to numpy.
        pre_output = pre_output.squeeze(0).cpu().numpy()
        post_output = post_output.squeeze(0).cpu().numpy()
        raw_image = raw_image.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(
            np.uint8)
        gt_label = gt_label.squeeze(0).cpu().numpy()

        # One column per predicted class plus two leading summary columns.
        cols = int(np.ceil((max(len(uncal_labels), len(cal_labels)) + 1))) + 1
        rows = 4

        plt.figure(figsize=(20, 20))

        # Input image (cv2 reads BGR; reverse channels for display).
        ax = plt.subplot(rows, cols, 1)
        ax.set_title("Input image")
        ax.imshow(raw_image[:, :, ::-1])
        ax.axis("off")

        # TODO(review): placeholder panel — titled "Difference Map" but it
        # re-displays the input image (kept to preserve the original layout).
        ax = plt.subplot(rows, cols, cols + 1)
        ax.set_title("Difference Map")
        ax.imshow(raw_image[:, :, ::-1])
        ax.axis("off")

        # Row 1: per-class probability maps, uncalibrated + CRF.
        for i, label in enumerate(uncal_labels):
            ax = plt.subplot(rows, cols, i + 3)
            ax.set_title("Uncalibrated + crf-" + classes[label])
            ax.imshow(pre_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 2: per-class probability maps, calibrated + CRF.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, cols + i + 3)
            ax.set_title("Calibrated (T={}) + CRF ".format(self.temp) +
                         classes[label])
            ax.imshow(post_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 3: where calibration DECREASED the class probability,
        # normalized by the largest decrease.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 2 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            max_dif = np.max(diff)
            dif_map = np.where(diff > 0, diff, 0)
            ax.set_title("decrease: " + classes[label] +
                         " max={:0.3f}".format(max_dif))
            # NOTE(review): max_dif may be 0, making this division yield
            # NaNs — same behavior as the original code.
            ax.imshow(dif_map / max_dif, cmap="nipy_spectral")
            ax.axis("off")

        # Row 4: where calibration INCREASED the class probability,
        # normalized by the largest increase (min_dif is <= 0).
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 3 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            min_dif = np.min(diff)
            dif_map = np.where(diff < 0, diff, 0)
            ax.set_title("increase: " + classes[label] +
                         " max={:0.3f}".format(-min_dif))
            ax.imshow(dif_map / min_dif, cmap="nipy_spectral")
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(outname)

    def giveComparisionImages_before_crf(self, pre_output, post_output,
                                         raw_image, gt_label, classes,
                                         outname):
        """Save a 4-row comparison figure for outputs before CRF refinement.

        Same layout as :meth:`giveComparisionImages_after_crf`: row 1 shows
        uncalibrated per-class maps, row 2 the calibrated (temperature
        ``self.temp``) maps, and rows 3/4 the normalized decrease/increase
        difference maps.

        Args:
            pre_output: [1, 21, h, w] tensor (uncalibrated).
            post_output: [1, 21, h, w] tensor (calibrated).
            raw_image: [1, 3, h, w] tensor, BGR order as read by cv2.
            gt_label: [1, h, w] tensor of ground-truth class ids (unused
                here beyond shape normalization).
            classes: sequence of class names indexed by class id.
            outname: path the figure is saved to.
        """
        # Classes actually predicted somewhere in each output.
        uncal_labels = np.unique(
            torch.argmax(pre_output.squeeze(0), dim=0).cpu().numpy())
        cal_labels = np.unique(
            torch.argmax(post_output.squeeze(0), dim=0).cpu().numpy())

        # Drop the batch dimension and convert everything to numpy.
        pre_output = pre_output.squeeze(0).cpu().numpy()
        post_output = post_output.squeeze(0).cpu().numpy()
        raw_image = raw_image.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(
            np.uint8)
        gt_label = gt_label.squeeze(0).cpu().numpy()

        # One column per predicted class plus two leading summary columns.
        cols = int(np.ceil((max(len(uncal_labels), len(cal_labels)) + 1))) + 1
        rows = 4

        plt.figure(figsize=(20, 20))

        # Input image (cv2 reads BGR; reverse channels for display).
        ax = plt.subplot(rows, cols, 1)
        ax.set_title("Input image")
        ax.imshow(raw_image[:, :, ::-1])
        ax.axis("off")

        # TODO(review): "loss" is a hard-coded placeholder (original had
        # "@ neelabh remove this"); the panel re-displays the input image.
        loss = 1.999999999999999
        ax = plt.subplot(rows, cols, cols + 1)
        ax.set_title("Accuracy dif = {:0.3f}".format(loss))
        ax.imshow(raw_image[:, :, ::-1])
        ax.axis("off")

        # Row 1: per-class probability maps before calibration.
        for i, label in enumerate(uncal_labels):
            ax = plt.subplot(rows, cols, i + 3)
            ax.set_title("Uncalibrated-" + classes[label])
            ax.imshow(pre_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 2: per-class probability maps after calibration.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, cols + i + 3)
            ax.set_title("Calibrated (T = {}) ".format(self.temp) +
                         classes[label])
            ax.imshow(post_output[label], cmap="nipy_spectral")
            ax.axis("off")

        # Row 3: where calibration DECREASED the class probability,
        # normalized by the largest decrease.
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 2 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            max_dif = np.max(diff)
            dif_map = np.where(diff > 0, diff, 0)
            ax.set_title("decrease: " + classes[label] +
                         " max={:0.3f}".format(max_dif))
            # NOTE(review): max_dif may be 0, making this division yield
            # NaNs — same behavior as the original code.
            ax.imshow(dif_map / max_dif, cmap="nipy_spectral")
            ax.axis("off")

        # Row 4: where calibration INCREASED the class probability,
        # normalized by the largest increase (min_dif is <= 0).
        for i, label in enumerate(cal_labels):
            ax = plt.subplot(rows, cols, 3 * cols + i + 3)
            diff = pre_output[label] - post_output[label]
            min_dif = np.min(diff)
            dif_map = np.where(diff < 0, diff, 0)
            ax.set_title("increase: " + classes[label] +
                         " max={:0.3f}".format(-min_dif))
            ax.imshow(dif_map / min_dif, cmap="nipy_spectral")
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(outname)

    def eceOperations(self,
                      endNAme,
                      bin_total,
                      bin_total_correct,
                      bin_conf_total,
                      temp=None):
        """Compute the iterative ECE loss for one image and record it.

        Appends a line to ``Results.txt`` under a temperature-specific
        directory and saves a reliability diagram plot.

        Args:
            endNAme: output file name for this image, ending in ".npy".
            bin_total / bin_total_correct / bin_conf_total: per-bin counts
                as produced by ``ece_criterion.get_collective_bins``.
            temp: temperature tag for the output directory; defaults to
                ``self.temp``.

        Returns:
            The scalar ECE loss.
        """
        eceLoss = self.ece_criterion.get_interative_loss(
            bin_total, bin_total_correct, bin_conf_total)
        if temp is None:
            temp = self.temp

        # Bug fix: the original used endNAme.strip(".npy"), which strips
        # any of the characters {., n, p, y} from BOTH ends of the name
        # (e.g. "n0012.npy" -> "0012"); remove the suffix explicitly.
        stem = endNAme[:-len(".npy")] if endNAme.endswith(".npy") else endNAme

        saveDir = os.path.join(self.ece_folder, self.postfix + f"_temp={temp}")
        makedirs(saveDir)

        # Append this image's loss; "with" guarantees the handle is closed
        # (the original leaked the file object).
        with open(os.path.join(saveDir, "Results.txt"), "a") as results_file:
            results_file.write(
                f"{stem}_temp={temp}\t\t\t ECE Loss: {eceLoss}\n")

        plot_folder = os.path.join(saveDir, "plots")
        makedirs(plot_folder)

        rel_diagram = visualization.ReliabilityDiagramIterative()
        plt_test_2 = rel_diagram.plot(bin_total,
                                      bin_total_correct,
                                      bin_conf_total,
                                      title="Reliability Diagram")
        plt_test_2.savefig(
            os.path.join(plot_folder, f"{stem}_temp={temp}.png"),
            bbox_inches="tight",
        )
        plt_test_2.close()
        return eceLoss

    def give_ece_order(self, model):
        """Evaluate the whole validation set and rank images by ECE loss.

        For each image the (uncalibrated, temp=1) ECE loss is computed;
        model outputs and targets are cached as .npy files so repeated
        runs skip the forward pass.

        Args:
            model: segmentation model exposing ``evaluate(image)``.

        Returns:
            List of ``[endName, filename, eceLoss]`` entries sorted by
            loss in descending order.
        """
        eceLosses = []
        for (image, target, filename) in tqdm(self.val_loader):
            # Per-image bin accumulators (reset for every sample).
            bin_total = []
            bin_total_correct = []
            bin_conf_total = []
            image = image.to(self.device)
            target = target.to(self.device)

            # Batch size 1 is assumed: only the first filename is used.
            filename = filename[0]
            endName = os.path.basename(filename).replace(".jpg", ".npy")

            # Cache the target tensor if not already on disk.
            npy_target_directory = "datasets/VOC_targets"
            npy_file = os.path.join(npy_target_directory, endName)
            if not os.path.isfile(npy_file):
                makedirs(npy_target_directory)
                np.save(npy_file, target.cpu().numpy())

            with torch.no_grad():
                # Use the cached model output if present, else run the
                # model and cache the result.
                npy_output_directory = "npy_outputs/npy_foggy1_VOC_outputs"
                npy_file = os.path.join(npy_output_directory, endName)
                if os.path.isfile(npy_file):
                    output = np.load(npy_file)
                    output = torch.Tensor(output).cuda()
                else:
                    makedirs(npy_output_directory)
                    output = model.evaluate(image)
                    np.save(npy_file, output.cpu().numpy())

                output_before_cali = output.clone()

                # Per-pixel confidence and predicted label for ECE binning.
                conf = np.max(output_before_cali.softmax(dim=1).cpu().numpy(),
                              axis=1)
                label = torch.argmax(output_before_cali, dim=1).cpu().numpy()
                (
                    bin_total_current,
                    bin_total_correct_current,
                    bin_conf_total_current,
                ) = self.ece_criterion.get_collective_bins(
                    conf, label,
                    target.cpu().numpy())
                bin_total.append(bin_total_current)
                bin_total_correct.append(bin_total_correct_current)
                bin_conf_total.append(bin_conf_total_current)

                # temp=1 means the uncalibrated ECE loss is recorded.
                eceLosses.append([
                    endName,
                    filename,
                    self.eceOperations(
                        endName,
                        bin_total,
                        bin_total_correct,
                        bin_conf_total,
                        temp=1,
                    ),
                ])

        # Worst-calibrated images first.
        eceLosses.sort(key=lambda x: x[2], reverse=True)

        return eceLosses

    def eval(self):
        """Generate calibration-comparison figures for the top-k worst
        ECE images.

        Loads a pre-sorted (endName, imageLoc, eceLoss) list from a pickle
        (the on-the-fly path is disabled — see the ``assert False`` note
        below), then for the ``top_k`` *lowest*-loss entries (the list is
        reversed first) records ECE stats and writes three comparison
        figures: before CRF, after CRF, and color maps.
        """
        # NOTE(review): the metric is reset but never updated in this
        # method; time_start is also unused.
        self.metric.reset()
        self.model.eval()
        model = self.model

        logging.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        import time

        time_start = time.time()
        # if(not self.useCRF):

        # first loop for finding ece errors
        if os.path.isfile("experiments/classCali/sorted_ecefoggy.pickle"):
            file = open("experiments/classCali/sorted_ecefoggy.pickle", "rb")
            # if os.path.isfile("experiments/classCali/sorted_ece.pickle"):
            #     file = open("experiments/classCali/sorted_ece.pickle", "rb")
            eceLosses = pickle.load(file)
            file.close()
        else:
            # NOTE(review): this branch is intentionally disabled — the
            # assert makes the recomputation below unreachable, so the
            # pickle MUST already exist.
            assert False
            eceLosses = self.give_ece_order(model)
            pickle.dump(eceLosses,
                        open("experiments/classCali/sorted_ece.pickle", "wb"))

        print("ECE sorting completed....")

        top_k = 10

        assert top_k > 0
        # Reversed: the list is sorted descending, so this walks from the
        # LOWEST ECE losses.
        eceLosses.reverse()
        # for i, (endName, imageLoc, eceLoss) in enumerate(tqdm(eceLosses[2:3])):
        for i, (endName, imageLoc,
                eceLoss) in enumerate(tqdm(eceLosses[:top_k])):
            # Loading outputs
            print(endName)
            # npy_output_directory = "npy_outputs/npy_VOC_outputs"
            npy_output_directory = "npy_outputs/npy_foggy1_VOC_outputs"
            npy_file = os.path.join(npy_output_directory, endName)
            output = np.load(npy_file)
            output = torch.Tensor(output).cuda()

            # loading targets
            npy_target_directory = "datasets/VOC_targets"
            npy_file = os.path.join(npy_target_directory, endName)
            target = np.load(npy_file)
            target = torch.Tensor(target).cuda()

            # print(image.shape)
            with torch.no_grad():

                # Temperature scaling: divide logits by self.temp.
                output_uncal = output.clone()
                output_cal = output / self.temp

                # ECE stats for the UNCALIBRATED output (temp=1).
                bin_total = []
                bin_total_correct = []
                bin_conf_total = []
                conf = np.max(output_uncal.softmax(dim=1).cpu().numpy(),
                              axis=1)
                label = torch.argmax(output_uncal, dim=1).cpu().numpy()
                # print(conf.shape,label.shape,target.shape)
                (
                    bin_total_current,
                    bin_total_correct_current,
                    bin_conf_total_current,
                ) = self.ece_criterion.get_collective_bins(
                    conf, label,
                    target.cpu().numpy())
                # import pdb; pdb.set_trace()
                bin_total.append(bin_total_current)
                bin_total_correct.append(bin_total_correct_current)
                bin_conf_total.append(bin_conf_total_current)

                # ECE stuff
                # if(not self.useCRF):
                self.eceOperations(endName,
                                   bin_total,
                                   bin_total_correct,
                                   bin_conf_total,
                                   temp=1)

                # ECE stats for the CALIBRATED output (default temp).
                bin_total = []
                bin_total_correct = []
                bin_conf_total = []
                conf = np.max(output_cal.softmax(dim=1).cpu().numpy(), axis=1)
                label = torch.argmax(output_cal, dim=1).cpu().numpy()
                # print(conf.shape,label.shape,target.shape)
                (
                    bin_total_current,
                    bin_total_correct_current,
                    bin_conf_total_current,
                ) = self.ece_criterion.get_collective_bins(
                    conf, label,
                    target.cpu().numpy())
                # import pdb; pdb.set_trace()
                bin_total.append(bin_total_current)
                bin_total_correct.append(bin_total_correct_current)
                bin_conf_total.append(bin_conf_total_current)

                # ECE stuff
                # if(not self.useCRF):
                self.eceOperations(
                    endName,
                    bin_total,
                    bin_total_correct,
                    bin_conf_total,
                )

                # Re-read the raw (unnormalized) image for the CRF
                # pairwise terms.
                raw_image = (cv2.imread(imageLoc, cv2.IMREAD_COLOR).astype(
                    np.float32).transpose(2, 0, 1))
                raw_image = torch.from_numpy(raw_image).to(self.device)
                raw_image = raw_image.unsqueeze(dim=0)

                # Setting up CRF
                crf = GaussCRF(
                    conf=get_default_conf(),
                    shape=output.shape[2:],
                    nclasses=len(self.classes),
                    use_gpu=True,
                )
                crf = crf.to(self.device)

                # CRF refinement of both the calibrated and uncalibrated
                # outputs; spatial sizes must match.
                assert output.shape[2:] == raw_image.shape[2:]
                # import pdb; pdb.set_trace()
                output_cal_crf = crf.forward(output_cal, raw_image)
                output_uncal_crf = crf.forward(output_uncal, raw_image)

                # Figure 1: uncalibrated vs calibrated, BEFORE CRF.
                comparisionFolder = "experiments/classCali/comparisionImages"
                saveFolder = os.path.join(
                    comparisionFolder,
                    "bcrf" + self.postfix + f"_temp={self.temp}")
                makedirs(saveFolder)
                saveName = os.path.join(saveFolder, os.path.basename(imageLoc))
                self.giveComparisionImages_before_crf(
                    output_uncal.softmax(dim=1),
                    output_cal.softmax(dim=1),
                    raw_image,
                    target,
                    self.classes,
                    saveName,
                )

                # Figure 2: uncalibrated vs calibrated, AFTER CRF.
                comparisionFolder = "experiments/classCali/comparisionImages"
                saveFolder = os.path.join(
                    comparisionFolder,
                    "crf" + self.postfix + f"_temp={self.temp}")
                makedirs(saveFolder)
                saveName = os.path.join(saveFolder, os.path.basename(imageLoc))
                self.giveComparisionImages_after_crf(
                    output_uncal_crf.softmax(dim=1),
                    output_cal_crf.softmax(dim=1),
                    raw_image,
                    target,
                    self.classes,
                    saveName,
                )

                # Figure 3: color-map comparison of the CRF-refined outputs.
                comparisionFolder = "experiments/classCali/comparisionImages"
                saveFolder = os.path.join(
                    comparisionFolder,
                    "cmap_" + self.postfix + f"_temp={self.temp}")
                makedirs(saveFolder)
                saveName = os.path.join(saveFolder, os.path.basename(imageLoc))
                self.giveComparisionImages_colormaps(
                    output_uncal_crf.softmax(dim=1),
                    output_cal_crf.softmax(dim=1),
                    raw_image,
                    target,
                    self.classes,
                    saveName,
                )
Exemple #10
0
class Trainer(object):
    def __init__(self, args):
        """Build the full training state from CLI args and the global cfg.

        Sets up: data transforms, train/val/test datasets and loaders,
        the segmentation model (optionally SyncBN / apex fp16 / DDP),
        the loss, optimizer, LR scheduler, checkpoint resume, and metrics.

        Args:
            args: parsed command-line namespace; fields used here include
                device, distributed, num_gpus, local_rank, resume.
        """
        self.args = args
        self.device = torch.device(args.device)
        self.use_fp16 = cfg.TRAIN.APEX

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        # dataset and dataloader; train uses TRAIN crop, val/test use TEST crop
        data_kwargs = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TRAIN.CROP_SIZE
        }

        data_kwargs_testval = {
            'transform': input_transform,
            'base_size': cfg.TRAIN.BASE_SIZE,
            'crop_size': cfg.TEST.CROP_SIZE
        }

        train_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                               split='val',
                                               mode='testval',
                                               **data_kwargs_testval)
        test_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
                                                split='test',
                                                mode='testval',
                                                **data_kwargs_testval)

        self.classes = test_dataset.classes

        # epoch length in iterations given the global (all-GPU) batch size
        self.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      cfg.TRAIN.BATCH_SIZE)
        self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch

        # the train batch sampler spans max_iters so a single loop over
        # train_loader covers the whole schedule (see Trainer.train)
        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      cfg.TRAIN.BATCH_SIZE,
                                                      self.max_iters,
                                                      drop_last=True)

        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)

        test_sampler = make_data_sampler(test_dataset, False, args.distributed)
        test_batch_sampler = make_batch_data_sampler(test_sampler,
                                                     cfg.TEST.BATCH_SIZE,
                                                     drop_last=False)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=cfg.DATASET.WORKERS,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.test_loader = data.DataLoader(dataset=test_dataset,
                                           batch_sampler=test_batch_sampler,
                                           num_workers=cfg.DATASET.WORKERS,
                                           pin_memory=True)

        # create network
        self.model = get_segmentation_model().to(self.device)

        # print params and flops (best-effort; some models break the tracer)
        if get_rank() == 0:
            try:
                show_flops_params(copy.deepcopy(self.model), args.device)
            except Exception as e:
                logging.warning('get flops and params error: {}'.format(e))
        if cfg.MODEL.BN_TYPE not in ['BN']:
            logging.info(
                'Batch norm type is {}, convert_sync_batchnorm is not effective'
                .format(cfg.MODEL.BN_TYPE))
        elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            logging.info('SyncBatchNorm is effective!')
        else:
            logging.info('Not use SyncBatchNorm!')
        # create criterion
        self.criterion = get_segmentation_loss(
            cfg.MODEL.MODEL_NAME,
            use_ohem=cfg.SOLVER.OHEM,
            aux=cfg.SOLVER.AUX,
            aux_weight=cfg.SOLVER.AUX_WEIGHT,
            ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)
        # optimizer, for model just includes encoder, decoder(head and auxlayer).
        self.optimizer = get_optimizer(self.model)
        # apex mixed precision must wrap model+optimizer before DDP
        if self.use_fp16:
            self.model, self.optimizer = apex.amp.initialize(self.model.cuda(),
                                                             self.optimizer,
                                                             opt_level="O1")
            logging.info('**** Initializing mixed precision done. ****')

        # lr scheduling
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)
        # resume checkpoint if needed
        self.start_epoch = 0
        if args.resume and os.path.isfile(args.resume):
            name, ext = os.path.splitext(args.resume)
            # BUGFIX: was `ext == '.pkl' or '.pth'`, which is always truthy
            # ('.pth' is a non-empty string), so the check never fired.
            assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
            logging.info('Resuming training, loading {}...'.format(
                args.resume))
            resume_state = torch.load(args.resume)
            self.model.load_state_dict(resume_state['state_dict'])
            self.start_epoch = resume_state['epoch']
            logging.info('resume train from epoch: {}'.format(
                self.start_epoch))
            if resume_state['optimizer'] is not None and resume_state[
                    'lr_scheduler'] is not None:
                logging.info(
                    'resume optimizer and lr scheduler from resume state..')
                self.optimizer.load_state_dict(resume_state['optimizer'])
                self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])

        # DDP wrapping comes last so it sees the final (possibly SyncBN/amp) model
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class,
                                         args.distributed)

    def train(self):
        """Run the full training schedule in a single pass over train_loader.

        The train batch sampler was built to span ``self.max_iters`` batches,
        so one ``for`` loop covers every epoch. Logs every ``args.log_iter``
        iterations, checkpoints once per epoch, and (unless ``args.skip_val``)
        runs validation + test every ``args.val_epoch`` epochs.
        """
        # Only rank 0 logs progress and writes checkpoints.
        self.save_to_disk = get_rank() == 0
        epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch

        start_time = time.time()
        logging.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        # When resuming, start the counter at the already-completed iterations
        # so epoch numbering and the log/checkpoint cadence stay consistent.
        iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0
        for (images, targets, _) in self.train_loader:
            epoch = iteration // iters_per_epoch + 1
            iteration += 1

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss for loss in loss_dict.values())
            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            if self.use_fp16:
                # apex AMP: scale the loss so fp16 gradients don't underflow.
                with apex.amp.scale_loss(losses,
                                         self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                losses.backward()
            self.optimizer.step()
            # LR scheduler is stepped per iteration, not per epoch.
            self.lr_scheduler.step()

            # ETA from the average time per completed iteration.
            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            if iteration % log_per_iters == 0 and self.save_to_disk:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".
                    format(
                        epoch, epochs, iteration % iters_per_epoch,
                        iters_per_epoch, self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            # Checkpoint once per epoch (rank 0 only).
            if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
                save_checkpoint(self.model,
                                epoch,
                                self.optimizer,
                                self.lr_scheduler,
                                is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(epoch)
                self.test()
                # validation/test switch the model to eval(); restore train mode.
                self.model.train()

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logging.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self, epoch):
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                if cfg.DATASET.MODE == 'val' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    size = image.size()[2:]
                    assert cfg.TEST.CROP_SIZE[0] == size[0]
                    assert cfg.TEST.CROP_SIZE[1] == size[1]
                    output = model(image)[0]

            self.metric.update(output, target)
            pixAcc, mIoU, category_iou = self.metric.get(
                return_category_iou=True)
            logging.info(
                "[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))
        pixAcc, mIoU = self.metric.get()
        logging.info(
            "[EVAL END] Epoch: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                epoch, pixAcc * 100, mIoU * 100))
        synchronize()

    def test(self, vis=False):
        """Evaluate on the test loader; log pixAcc/mIoU and a per-class IoU table.

        Args:
            vis: if True, skip metric accumulation and instead save colorized
                prediction images (or, with the hard-coded ``save_gt`` flag,
                ground-truth images/masks) to disk.
        """
        self.metric.reset()
        # Unwrap DDP so we call the underlying module directly.
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()
        model.eval()
        for i, (image, target, filename) in enumerate(self.test_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                # Both branches run the same forward; the else branch only adds
                # a sanity check that the input matches the configured crop size.
                if cfg.DATASET.MODE == 'test' or cfg.TEST.CROP_SIZE is None:
                    output = model(image)[0]
                else:
                    size = image.size()[2:]
                    assert cfg.TEST.CROP_SIZE[0] == size[0]
                    assert cfg.TEST.CROP_SIZE[1] == size[1]
                    output = model(image)[0]

            if vis:
                # NOTE(review): save_gt is a hard-coded debug switch; the True
                # branch reads from a machine-specific absolute path.
                save_gt = False
                if save_gt:
                    test_path = '/mnt/lustre/xieenze/xez_space/TransparentSeg/datasets/transparent/Trans10K_cls12/test/images'
                    save_path = 'workdirs/trans10kv2/gt_img'
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    gt_img = Image.open(os.path.join(test_path,
                                                     filename[0])).resize(
                                                         (512, 512))
                    gt_img.save(os.path.join(save_path, str(i) + '.png'))

                    gt_mask = target[0].data.cpu().numpy()
                    vis_gt = get_color_pallete(gt_mask, dataset='trans10kv2')
                    save_path = 'workdirs/trans10kv2/gt_mask'
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    vis_gt.save(os.path.join(save_path, str(i) + '.png'))
                else:
                    # argmax over the class channel of the first sample,
                    # then colorize and save the prediction map.
                    vis_pred = output[0].permute(
                        1, 2, 0).argmax(-1).data.cpu().numpy()
                    vis_pred = get_color_pallete(vis_pred,
                                                 dataset='trans10kv2')
                    save_path = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'vis')
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    vis_pred.save(os.path.join(save_path, str(i) + '.png'))
                print("[VIS TEST] Sample: {:d}".format(i + 1))
                # Visualization mode skips metric updates entirely.
                continue

            self.metric.update(output, target)
            pixAcc, mIoU, category_iou = self.metric.get(
                return_category_iou=True)
            logging.info(
                "[TEST] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

        synchronize()
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info("[TEST END]  pixAcc: {:.3f}, mIoU: {:.3f}".format(
            pixAcc * 100, mIoU * 100))

        # Per-class IoU table; showindex="always" supplies the 'class id'
        # column, so each row carries only [name, iou].
        headers = ['class id', 'class name', 'iou']
        table = []
        for i, cls_name in enumerate(self.classes):
            table.append([cls_name, category_iou[i]])
        logging.info('Category iou: \n {}'.format(
            tabulate(table,
                     headers,
                     tablefmt='grid',
                     showindex="always",
                     numalign='center',
                     stralign='center')))