class Trainer(object):
    """Set up a segmentation model from ``args`` and run test/validation passes.

    The constructor builds the data loaders, the network selected by
    ``args.model``, the loss, and an ``Evaluator``. It optionally wraps the
    model in ``DataParallel`` and restores a checkpoint. No training loop is
    defined in this class.
    """

    def __init__(self, args):
        """Build loaders, model, criterion and evaluator from ``args``.

        Raises:
            SystemExit: with status 1 if ``args.norm`` is not ``"gn"``,
                ``"bn"`` or ``"abn"``.
            ValueError: if ``args.model`` is not ``"deeplabv3+"``,
                ``"deeplabv3"`` or ``"fpn"``.
            RuntimeError: if ``args.resume`` is set but no such file exists.
        """
        self.args = args
        self.vs = Vs(args.dataset)

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        # Resolve the normalization layer from args.norm / args.sync_bn.
        # NOTE(review): ``norm`` is not passed to any model constructor below
        # (DeepLabv3 receives the raw ``args.norm`` string instead). This
        # branch still validates ``args.norm`` and, for synchronized ABN,
        # calls ``syncabn(gpu_ids)``. Confirm whether ``Norm=norm`` was meant.
        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            norm = syncbn if self.args.sync_bn else bn
        elif self.args.norm == "abn":
            norm = syncabn(self.args.gpu_ids) if self.args.sync_bn else abn
        else:
            print("Please check the norm.")
            # A bad configuration is a failure, so exit with a non-zero
            # status. Unlike the interactive ``exit()`` helper, this does not
            # depend on the ``site`` module being loaded.
            raise SystemExit(1)

        # Define network
        if self.args.model == "deeplabv3+":
            model = DeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        elif self.args.model == "deeplabv3":
            model = DeepLabv3(
                Norm=args.norm,
                backbone=args.backbone,
                output_stride=args.out_stride,
                num_classes=self.nclass,
                freeze_bn=args.freeze_bn,
            )
        elif self.args.model == "fpn":
            model = FPN(args=args, num_classes=self.nclass)
        else:
            # Fail fast. Previously ``model`` was left unbound and the error
            # surfaced later as an UnboundLocalError, after the class
            # weights had been computed.
            raise ValueError("Unsupported model: {}".format(self.args.model))

        # Define Criterion
        # Optionally use class-balanced weights. They are read from a cached
        # .npy file if present, otherwise computed from the training loader.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            # With CUDA the model is wrapped in DataParallel, so the weights
            # are loaded into the wrapped ``.module``.
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def test(self):
        """Predict on the test split and save the results via ``self.vs``.

        Forces ``self.args.examine`` to False. When ``args.color`` is set,
        colourised predictions are saved (with the images moved to NumPy);
        otherwise id maps are saved. Everything goes to ``args.save_dir``.
        """
        self.model.eval()
        self.args.examine = False
        tbar = tqdm(self.test_loader, desc="\r")
        need_images = bool(self.args.color)
        for sample in tbar:
            images = sample["image"]
            names = sample["name"]
            if self.args.cuda:
                images = images.cuda()
            with torch.no_grad():
                output = self.model(images)
            # Class index per pixel: argmax over axis 1 of the model output.
            preds = np.argmax(output.data.cpu().numpy(), axis=1)
            if need_images:
                images = images.cpu().numpy()
            if self.args.color:
                self.vs.predict_color(preds, images, names,
                                      self.args.save_dir)
            else:
                self.vs.predict_id(preds, names, self.args.save_dir)

    def validation(self, epoch):
        """Evaluate on the validation split and print the metrics.

        Args:
            epoch: not used by the current implementation; kept for caller
                compatibility.

        Per-batch predictions are optionally saved via ``self.vs`` when
        ``args.id``, ``args.color`` or ``args.examine`` is set. Prints pixel
        accuracy, per-class accuracy, mIoU, fwIoU and the summed loss.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        need_images = bool(self.args.color or self.args.examine)
        for i, sample in enumerate(tbar):
            images, targets = sample["image"], sample["label"]
            names = sample["name"]
            if self.args.cuda:
                images, targets = images.cuda(), targets.cuda()
            with torch.no_grad():
                output = self.model(images)
            loss = self.criterion(output, targets)
            test_loss += loss.item()
            # The progress bar shows the running mean loss per batch.
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            preds = np.argmax(output.data.cpu().numpy(), axis=1)
            targets = targets.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(targets, preds)
            if need_images:
                images = images.cpu().numpy()
            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names,
                                      self.args.save_dir)
            if self.args.examine:
                self.vs.predict_examine(preds, targets, images, names,
                                        self.args.save_dir)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Validation:")
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        # NOTE: this is the summed loss over all batches, not the mean.
        print("Loss: %.3f" % test_loss)
class Tester(object):
    """Evaluate a model with SPNet-style classifiers on a chosen test set.

    ``args.test_set`` selects which classifier checkpoint is loaded for
    testing ('seen', 'unseen' or 'all'). ``args.task`` selects
    classification (top-1/top-5) or segmentation (mean IoU via ``get_iou``)
    scoring.
    """

    def __init__(self, args, verbose=True):
        """Build loaders, model and criterion; load classifiers/checkpoint.

        Raises:
            NotImplementedError: if any of ``args.call``, ``args.cseen`` or
                ``args.cunseen`` is None, or if resuming without CUDA.
            RuntimeError: if ``args.test_set`` is not 'seen', 'unseen' or
                'all', or if a checkpoint path does not exist.
        """
        self.args = args
        self.verbose = verbose

        # Define Dataloader. ``nclasses`` is indexed below by the keys
        # 'train', 'val' and 'test'.
        (self.train_loader, self.val_loader, self.test_loader,
         self.nclasses) = make_data_loader(args, verbose)

        if self.args.task == 'segmentation':
            self.vs = Vs(args.dataset)
            self.evaluator = Evaluator(self.nclasses['val'])

        # Define Network
        model = Model(args, self.nclasses['train'], self.nclasses['test'])

        # Define Criterion
        criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_index)
        self.model, self.criterion = model, criterion

        # Loading Classifier (SPNet style)
        if args.call is None or args.cseen is None or args.cunseen is None:
            raise NotImplementedError(
                "Classifiers for 'all', 'seen', 'unseen' should be loaded")
        classifier_paths = {
            'unseen': args.cunseen,
            'all': args.call,
            'seen': args.cseen,
        }
        if args.test_set not in classifier_paths:
            raise RuntimeError("{}".format(args.test_set))
        ctest = classifier_paths[args.test_set]
        if not os.path.isfile(ctest):
            raise RuntimeError(
                "=> no checkpoint for clasifier found at '{}'".format(ctest))
        self.model.load_test(ctest)
        if verbose:
            print("Classifiers checkpoint successfully loaded from {}, {}".
                  format(args.cseen, ctest))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("{}: No such checkpoint exists".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                pretrained_dict = checkpoint['state_dict']
                state_dict = self.model.state_dict()
                # Copy only keys the model knows about, skipping any
                # 'classifier' entries so the weights installed by
                # ``load_test`` above are not overwritten (presumably those
                # are the classifier weights -- confirm in Model.load_test).
                model_dict = {
                    k: v
                    for k, v in pretrained_dict.items()
                    if 'classifier' not in k and k in state_dict
                }
                state_dict.update(model_dict)
                self.model.load_state_dict(state_dict)
            else:
                print("Please use CUDA")
                raise NotImplementedError
            self.best_pred = checkpoint['best_pred']
            if verbose:
                print("Loading {} (epoch {}) successfully done".format(
                    args.resume, checkpoint['epoch']))

        # Using CUDA
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()

        if args.ft:
            args.start_epoch = 0

    def test(self):
        """Run the model over the test loader and optionally save outputs.

        Prints, per output channel, the mean model output over every pixel
        of every test image (channels beyond ``nclasses['test']`` are not
        supported). Predictions are saved via ``self.vs`` when ``args.id``
        or ``args.color`` is set.
        """
        self.model.eval()
        tbar = tqdm(self.test_loader)
        # Running per-channel sums and pixel count. Previously the values
        # were overwritten on every batch, so only the last batch was
        # reported.
        logit_sums = np.zeros(self.nclasses['test'], dtype=np.float64)
        n_pixels = 0
        for sample in tbar:
            images = sample['image'].cuda()
            names = sample['name']
            # One False flag per image in the batch; its meaning is defined
            # by the model's forward() -- TODO confirm.
            falses = torch.zeros(images.shape[0], dtype=torch.bool,
                                 device=images.device)
            with torch.no_grad():
                output = self.model(images, falses)
            # Output is indexed as (batch, channel, H, W) below.
            preds_np = output.cpu().numpy()
            channel_sums = preds_np.sum(axis=(0, 2, 3), dtype=np.float64)
            logit_sums[:channel_sums.shape[0]] += channel_sums
            n_pixels += preds_np.shape[0] * preds_np.shape[2] * preds_np.shape[3]

            preds = torch.argmax(output, dim=1)
            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names,
                                      self.args.save_dir)
        logits = logit_sums / n_pixels if n_pixels else logit_sums
        print(logits.tolist())

    def val(self):
        """Score the model on the test loader.

        For ``task == 'classification'``, prints top-1/top-5 accuracy.
        For ``task == 'segmentation'``, prints and returns the accumulated
        ``get_iou`` score divided by the number of images. The result is NaN
        if the loader produced no images. Predictions are optionally saved
        via ``self.vs``.
        """
        task = self.args.task
        if task == 'classification':
            top1 = AverageMeter('Acc@1', ':6.2f')
            top5 = AverageMeter('Acc@5', ':6.2f')
        elif task == 'segmentation':
            self.evaluator.reset()
        if self.args.id or self.args.color or self.args.examine:
            os.makedirs(self.args.save_dir, exist_ok=True)

        self.model.eval()
        tbar = tqdm(self.test_loader)
        miou = 0.0
        count = 0
        # For 'seen'/'unseen' test sets, get_iou is called with ignore0=True.
        ignore0 = self.args.test_set in ('seen', 'unseen')
        for sample in tbar:
            images = sample['image'].cuda()
            targets = sample['label'].cuda().long()
            names = sample['name']
            falses = torch.zeros(images.shape[0], dtype=torch.bool,
                                 device=images.device)
            with torch.no_grad():
                outputs = self.model(images, falses)

            # Score record
            if task == 'classification':
                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
            elif task == 'segmentation':
                preds = torch.argmax(outputs, dim=1)
                count += preds.shape[0]
                miou += get_iou(preds, targets,
                                n_classes=self.nclasses['test'],
                                ignore0=ignore0)
                if self.args.id:
                    self.vs.predict_id(preds, names, self.args.save_dir)
                if self.args.color:
                    self.vs.predict_color(preds, images, names,
                                          self.args.save_dir)
                if self.args.examine:
                    self.vs.predict_examine(preds, targets, images, names,
                                            self.args.save_dir)

        if task == 'classification':
            print("Top-1: %.3f, Top-5: %.3f" % (top1.avg, top5.avg))
        elif task == 'segmentation':
            # NOTE(review): self.evaluator is reset above but never fed in
            # this loop, so only the get_iou-based score is reported.
            mean_iou = miou / count if count else float('nan')
            print("confidence:{} mIoU:{}".format(self.args.confidence,
                                                 mean_iou))
            return mean_iou