Example #1
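# The imports below are reconstructed from the names this example uses. The
# standard/third-party imports are unambiguous; the project-local helpers
# (Vs, make_data_loader, Path, the gn/bn/syncbn/abn/syncabn norm layers,
# DeepLab, DeepLabv3, FPN, SegmentationLosses, calculate_weigths_labels,
# Evaluator, patch_replication_callback) come from the surrounding repository,
# whose module paths are not shown here.
import os

import numpy as np
import torch
from tqdm import tqdm
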
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.vs = Vs(args.dataset)

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

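        # Choose the normalization layer: 'gn', 'bn' (optionally synchronized),
        # or 'abn' (optionally synchronized across the configured GPUs).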
        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            print("Please check the norm.")
            exit()

        # Define network
        if self.args.model == "deeplabv3+":
            model = DeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        elif self.args.model == "deeplabv3":
            model = DeepLabv3(
                Norm=args.norm,
                backbone=args.backbone,
                output_stride=args.out_stride,
                num_classes=self.nclass,
                freeze_bn=args.freeze_bn,
            )
        elif self.args.model == "fpn":
            model = FPN(args=args, num_classes=self.nclass)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

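    # Inference over the test split: predictions are written to args.save_dir,
    # either as class-id maps or as color visualizations (see Vs).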
    def test(self):
        self.model.eval()
        self.args.examine = False
        tbar = tqdm(self.test_loader, desc="\r")
        __image = bool(self.args.color)
        for i, sample in enumerate(tbar):
            images = sample["image"]
            names = sample["name"]
            if self.args.cuda:
                images = images.cuda()
            with torch.no_grad():
                output = self.model(images)
            preds = output.data.cpu().numpy()
            preds = np.argmax(preds, axis=1)
            if __image:
                images = images.cpu().numpy()
            if not self.args.color:
                self.vs.predict_id(preds, names, self.args.save_dir)
            else:
                self.vs.predict_color(preds, images, names, self.args.save_dir)

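    # Evaluation over the validation split: loss and confusion-matrix metrics
    # are accumulated in Evaluator; id/color/examine dumps are optional.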
    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        __image = bool(self.args.color or self.args.examine)
        for i, sample in enumerate(tbar):
            images, targets = sample["image"], sample["label"]
            names = sample["name"]
            if self.args.cuda:
                images, targets = images.cuda(), targets.cuda()
            with torch.no_grad():
                output = self.model(images)
            loss = self.criterion(output, targets)
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            preds = output.data.cpu().numpy()
            targets = targets.cpu().numpy()
            preds = np.argmax(preds, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(targets, preds)
            if __image:
                images = images.cpu().numpy()
            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names, self.args.save_dir)
            if self.args.examine:
                self.vs.predict_examine(preds, targets, images, names,
                                        self.args.save_dir)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Validation:")
        # print(
        #     "[Epoch: %d, numImages: %5d]"
        #     % (epoch, i * self.args.batch_size + image.data.shape[0])
        # )
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Loss: %.3f" % test_loss)
Example #2
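# Reconstructed imports for the names this example uses. Standard/third-party
# imports are listed here; the project-local helpers (Vs, make_data_loader,
# Evaluator, Model, AverageMeter, accuracy, get_iou) come from the surrounding
# repository, whose module paths are not shown here.
import os

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
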
class Tester(object):
    def __init__(self, args, verbose=True):
        self.args = args
        self.verbose = verbose

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclasses = make_data_loader(
            args, verbose)

        if self.args.task == 'segmentation':
            self.vs = Vs(args.dataset)
            self.evaluator = Evaluator(self.nclasses['val'])

        # Define Network
        model = Model(args, self.nclasses['train'], self.nclasses['test'])

        # Define Criterion
        criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_index)

        self.model, self.criterion = model, criterion

        # Loading Classifier (SPNet style)
        if args.call is None or args.cseen is None or args.cunseen is None:
            raise NotImplementedError(
                "Classifiers for 'all', 'seen', 'unseen' should be loaded")
        else:
            if args.test_set == 'unseen':
                ctest = args.cunseen
            elif args.test_set == 'all':
                ctest = args.call
            elif args.test_set == 'seen':
                ctest = args.cseen
            else:
                raise RuntimeError("{}".format(args.test_set))

            if not os.path.isfile(ctest):
                raise RuntimeError(
                    "=> no checkpoint for clasifier found at '{}'".format(
                        ctest))

            self.model.load_test(ctest)

            if verbose:
                print("Classifiers checkpoint successfully loaded from {}, {}".
                      format(args.cseen, ctest))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("{}: No such checkpoint exists".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.cuda:
                pretrained_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.state_dict()
                for k, v in pretrained_dict.items():
                    if 'classifier' in k: continue
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.load_state_dict(state_dict)
            else:
                print("Please use CUDA")
                raise NotImplementedError

            self.best_pred = checkpoint['best_pred']
            if verbose:
                print("Loading {} (epoch {}) successfully done".format(
                    args.resume, checkpoint['epoch']))

        # Using CUDA
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()

        if args.ft:
            args.start_epoch = 0

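    # Inference over the test split: records the per-class mean of the model
    # outputs for each batch and optionally writes id/color predictions via Vs.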
    def test(self):
        self.model.eval()

        tbar = tqdm(self.test_loader)
        logits = [0.0] * self.nclasses['test']

        for i, sample in enumerate(tbar):
            images = sample['image'].cuda()
            names = sample['name']
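            # One False flag per image in the batch; the model's forward pass
            # takes this as its second input.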
            falses = torch.from_numpy(np.array([False] *
                                               images.shape[0])).cuda()

            with torch.no_grad():
                output = self.model(images, falses)

            preds_np = output.cpu().numpy()
            # Per-class mean of the outputs over the current batch
            for c in range(preds_np.shape[1]):
                logits[c] = np.mean(preds_np[:, c, :, :])

            preds = torch.argmax(output, dim=1)

            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names, self.args.save_dir)
            #tbar.set_description("{} : {}".format(names[0], self.id2class(preds[0])))
        print(logits)

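    # Evaluation: top-1/top-5 accuracy for classification, or a mean-IoU score
    # accumulated via get_iou for segmentation, with optional visual dumps.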
    def val(self):
        if self.args.task == 'classification':
            top1 = AverageMeter('Acc@1', ':6.2f')
            top5 = AverageMeter('Acc@5', ':6.2f')
        elif self.args.task == 'segmentation':
            self.evaluator.reset()

        if self.args.id or self.args.color or self.args.examine:
            if not os.path.exists(self.args.save_dir):
                os.makedirs(self.args.save_dir)

        self.model.eval()
        tbar = tqdm(self.test_loader)
        miou = 0.0
        count = 0
        for i, sample in enumerate(tbar):
            images = sample['image'].cuda()
            targets = sample['label'].cuda().long()
            names = sample['name']
            falses = torch.from_numpy(np.array([False] *
                                               images.shape[0])).cuda()
            with torch.no_grad():
                outputs = self.model(images, falses)

            # Score record
            if self.args.task == 'classification':
                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
            elif self.args.task == 'segmentation':
                preds = torch.argmax(outputs, dim=1)
                count += preds.shape[0]
                miou += get_iou(preds,
                                targets,
                                n_classes=self.nclasses['test'],
                                ignore0=self.args.test_set in ('seen', 'unseen'))
                if self.args.id:
                    self.vs.predict_id(preds, names, self.args.save_dir)
                if self.args.color:
                    self.vs.predict_color(preds, images, names,
                                          self.args.save_dir)
                if self.args.examine:
                    self.vs.predict_examine(preds, targets, images, names,
                                            self.args.save_dir)

        if self.args.task == 'classification':
            _top1 = top1.avg
            _top5 = top5.avg
            print("Top-1: %.3f, Top-5: %.3f" % (_top1, _top5))
        elif self.args.task == 'segmentation':
            # acc = self.evaluator.Pixel_Accuracy()
            # acc_class = self.evaluator.Pixel_Accuracy_Class()
            # miou = self.evaluator.Mean_Intersection_over_Union()
            # fwiou = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            # print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU:{}".format(acc, acc_class, miou, fwiou))
            print("confidence:{} mIoU:{}".format(self.args.confidence,
                                                 miou / count))
            return miou / count
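
# --- Usage sketch (not part of the original example) ---
# A minimal, hedged driver for the Tester above. The attribute names on `args`
# are the ones read in this snippet; the paths and values are placeholders, and
# Model / make_data_loader may read further attributes not listed here.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        task='segmentation', dataset='cityscapes', workers=4, ignore_index=255,
        call='classifier_all.pth', cseen='classifier_seen.pth',
        cunseen='classifier_unseen.pth', test_set='unseen',
        resume='checkpoint.pth.tar', cuda=True, gpu_ids=[0], ft=False,
        id=False, color=False, examine=False, save_dir='./run', confidence=0.0,
    )
    tester = Tester(args)
    print(tester.val())  # for segmentation, returns the averaged mIoU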