class BaseTrainer(object):
    def __init__(self, cfg, model, train_dl, val_dl, criterion, optimizer,
                 scheduler, start_epoch, nf):
        self.cfg = cfg
        self.logger = gl.get_value('logger')
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model
        self.train_epoch = start_epoch
        self.batch_cnt = 0
        use_cuda = torch.cuda.is_available()
        if cfg.MODEL.DEVICE != 'cuda':
            use_cuda = False
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        # compare the device *type*, not the torch.device object, to a string
        self.mgpu = self.device.type == 'cuda' and torch.cuda.device_count() > 1
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.start_time = time.time()
        self.epoch_start_time = time.time()
        self.total_loss = AvgerageMeter()
        self.global_loss = AvgerageMeter()
        self.drop_loss = AvgerageMeter()
        self.crop_loss = AvgerageMeter()
        self.center_loss = AvgerageMeter()
        self.best_metric = 0.0
        self.nf = nf
        self.logger.info('Trainer Start')
        self.y_true = list()
        self.y_pred = list()
        self.n_count = 0.0
        self.n_correct = 0.0

    #@(Events.ITERATION_COMPLETED)
    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.MISC.LOG_PERIOD == 0 or self.batch_cnt == len(self.train_dl):
            elapsed = time.time() - self.start_time
            pstring = (
                "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | lr {:1.6f} | "
                "acc {:03.03%} | loss {:3.4f}".format(
                    self.train_epoch, self.batch_cnt, len(self.train_dl),
                    elapsed / self.cfg.MISC.LOG_PERIOD,
                    self.scheduler.get_lr()[0],
                    self.n_correct / self.n_count, self.total_loss.avg))
            self.logger.info(pstring)
            if self.cfg.MODEL.LOSS_ATT is True:
                pstring_loss = (
                    "global loss {:3.4f} | drop loss {:3.4f} | crop loss {:3.4f} | center loss {:3.4f}"
                    .format(self.global_loss.avg, self.drop_loss.avg,
                            self.crop_loss.avg, self.center_loss.avg))
                self.logger.info(pstring_loss)
            if 'SVreid' in self.cfg.MODEL.NAME:
                pstring_loss = ("global loss {:3.4f} | center loss {:3.4f}"
                                .format(self.global_loss.avg, self.center_loss.avg))
                self.logger.info(pstring_loss)
            self.start_time = time.time()

    def step(self, batch):
        self.model.train()
        if 'SVBNN' in self.cfg.MODEL.NAME:
            images, targets_a, meta_infos_a = self.parse_batch(batch[0])
            images_bal, targets_b, meta_infos_b = self.parse_batch(batch[1])
            targets = targets_b
            # parabolic decay of the mixing weight over training
            lam = 1.0 - ((self.train_epoch - 1) / (self.cfg.SOLVER.EPOCHS - 1)) ** 2
            #lam = 1.0 - ((self.train_epoch - 1) / (self.cfg.SOLVER.EPOCHS - 1))  # linear decay
        else:
            images, targets, meta_infos = self.parse_batch(batch)
            if self.cfg.MODEL.MIXUP is True:
                if self.cfg.MODEL.MIXUP_MODE == 'mixup':  # ('mixup', 'cutmix')
                    mixfunc = mixup_data
                elif self.cfg.MODEL.MIXUP_MODE == 'cutmix':
                    mixfunc = cutmix_data
                images, targets_a, targets_b, lam = mixfunc(
                    images, targets,
                    alpha=self.cfg.MODEL.MIXUP_ALPHA,
                    prob=self.cfg.MODEL.MIXUP_PROB)
            else:
                lam = 1.0
        if 'SingleView' in self.cfg.MODEL.NAME:
            outputs = self.model(images)
        elif 'SVMeta' in self.cfg.MODEL.NAME:
            outputs = self.model(images, meta_infos)
        elif 'SVAtt' in self.cfg.MODEL.NAME or 'SVDB' in self.cfg.MODEL.NAME or 'SVreid' in self.cfg.MODEL.NAME:
            outputs = self.model(images, targets, lam)
        elif 'SVBNN' in self.cfg.MODEL.NAME:
            outputs = self.model(x=images, x_rb=images_bal, alpha=lam)
        else:
            raise ValueError(f'unknown model type {self.cfg.MODEL.NAME}')
        if self.cfg.MODEL.MIXUP is True or 'SVBNN' in self.cfg.MODEL.NAME:
            loss = mixup_criterion(self.criterion, outputs, targets_a, targets_b, lam)
        else:
            loss = self.criterion(outputs, targets)
        if isinstance(loss, (tuple, list)):
            # loss[0] is the total; loss[1] is a dict of named components
            self.total_loss.update(loss[0].item(), images.size(0))
            self.global_loss.update(loss[1]['global'].item(), images.size(0))
            if 'crop' in loss[1].keys():
                self.crop_loss.update(loss[1]['crop'].item(), images.size(0))
            if 'drop' in loss[1].keys():
                self.drop_loss.update(loss[1]['drop'].item(), images.size(0))
            self.center_loss.update(loss[1]['center'].item(), images.size(0))
        else:
            self.total_loss.update(loss.item(), images.size(0))
        self.optimizer.zero_grad()
        if isinstance(outputs, (list, tuple)):
            _, preds = torch.max(outputs[0], 1)
            loss[0].backward()
        else:  # tensor
            _, preds = torch.max(outputs, 1)
            loss.backward()
        self.y_pred.extend(preds.cpu().numpy())
        self.y_true.extend(targets.cpu().numpy())
        #clip_grad_norm_(net.parameters(), 0.5)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        self.n_correct += torch.sum((preds == targets).float()).item()
        self.n_count += images.size(0)
        return

    def evaluate(self, data_dl):
        self.model.eval()
        y_true_eval = list()
        y_pred_eval = list()
        total_loss_eval = AvgerageMeter()
        loss_global_eval = AvgerageMeter()
        loss_local_eval = AvgerageMeter()
        with torch.no_grad():
            for batch in data_dl:
                images, targets, meta_infos = self.parse_batch(batch)
                if 'SingleView' in self.cfg.MODEL.NAME or 'SVBNN' in self.cfg.MODEL.NAME:
                    outputs = self.model(images)
                elif 'SVMeta' in self.cfg.MODEL.NAME:
                    outputs = self.model(images, meta_infos)
                elif 'SVAtt' in self.cfg.MODEL.NAME or 'SVDB' in self.cfg.MODEL.NAME or 'SVreid' in self.cfg.MODEL.NAME:
                    outputs = self.model(images, targets)
                else:
                    raise ValueError(f'unknown model type {self.cfg.MODEL.NAME}')
                if 'SVreid' in self.cfg.MODEL.NAME:
                    cls_score, _ = outputs
                    loss = self.criterion.crit(cls_score, targets)
                    total_loss_eval.update(loss.item(), images.size(0))
                    _, preds = torch.max(outputs[0], 1)
                elif isinstance(outputs, (list, tuple)):
                    # sum the two heads' losses; average their probabilities for prediction
                    loss1 = self.criterion.crit(outputs[0], targets)
                    loss2 = self.criterion.crit(outputs[1], targets)
                    loss_global_eval.update(loss1.item(), images.size(0))
                    loss_local_eval.update(loss2.item(), images.size(0))
                    loss = loss1 + loss2
                    total_loss_eval.update(loss.item(), images.size(0))
                    outputs = F.softmax(0.5 * (outputs[0] + outputs[1]), dim=-1)
                    _, preds = torch.max(outputs, 1)
                else:  # tensor
                    loss = self.criterion(outputs, targets)
                    total_loss_eval.update(loss.item(), images.size(0))
                    _, preds = torch.max(outputs, 1)
                y_true_eval.extend(targets.cpu().numpy())
                y_pred_eval.extend(preds.cpu().numpy())
        metric = {'loss': total_loss_eval.avg,
                  'y_true_eval': y_true_eval,
                  'y_pred_eval': y_pred_eval}
        if self.cfg.MODEL.LOSS_ATT is True:
            metric['loss_global'] = loss_global_eval.avg
            metric['loss_local'] = loss_local_eval.avg
        return metric

    #@(Events.EPOCH_COMPLETED)
    def handle_new_epoch(self):
        self.batch_cnt = 0
        if self.train_epoch % self.cfg.MISC.VALID_EPOCH == 0:
            self.logger.info("-" * 96)
            val_metric = self.evaluate(self.val_dl)
            train_acc, valid_acc = self.print_stat(
                self.y_pred, self.y_true,
                val_metric['y_pred_eval'], val_metric['y_true_eval'])
            # use balanced accuracy as the model-selection metric instead of loss
            val_ms = float(valid_acc)
            val_loss = val_metric['loss']
            is_best = val_ms > self.best_metric
            if (is_best and self.train_epoch >= self.cfg.MISC.SAV_EPOCH) or \
                    (self.train_epoch % self.cfg.MISC.SAV_EPOCH == 0):
                dict_sav = {'train_loss': self.total_loss.avg,
                            'valid_loss': val_loss,
                            'train_acc': train_acc,
                            'valid_acc': valid_acc}
                self.save(dict_sav, is_best)
            if is_best:
                self.best_metric = val_ms
            self.logger.info(
                "| end of epoch {:3d} | time: {:5.02f}s | val loss {:5.04f} | "
                "val ms {:5.03f} | best ms {:5.03f} |".format(
                    self.train_epoch, (time.time() - self.epoch_start_time),
                    val_loss, val_ms, self.best_metric))
            if self.cfg.MODEL.LOSS_ATT is True:
                self.logger.info(
                    "val global {:2.04f} | val local {:2.04f} | val total {:2.04f}".format(
                        val_metric['loss_global'], val_metric['loss_local'], val_ms))
            self.logger.info('Train/val Acc: {:.4f}, {:.4f}'.format(train_acc, valid_acc))
            self.logger.info("-" * 96)
        self.epoch_start_time = time.time()
        self.train_epoch += 1
        self.total_loss.reset()
        self.global_loss.reset()
        self.drop_loss.reset()
        self.crop_loss.reset()
        self.center_loss.reset()
        self.y_true = list()
        self.y_pred = list()
        self.n_count = 0.0
        self.n_correct = 0.0

    def save(self, dict_sav, is_best=False):
        train_loss = dict_sav['train_loss']
        valid_loss = dict_sav['valid_loss']
        train_acc = dict_sav['train_acc']
        valid_acc = dict_sav['valid_acc']
        model_path = osp.join(
            self.cfg.MISC.OUT_DIR,
            f"{self.cfg.MODEL.NAME}-Fold-{self.nf}-Epoch-{self.train_epoch}"
            f"-trainloss-{train_loss:.4f}-loss-{valid_loss:.4f}"
            f"-trainacc-{train_acc:.4f}-acc-{valid_acc:.4f}.pth")
        torch.save(self.model.state_dict(), model_path)
        if is_best:
            if self.cfg.DATASETS.K_FOLD == 1:
                best_model_fn = osp.join(self.cfg.MISC.OUT_DIR,
                                         f"{self.cfg.MODEL.NAME}-best.pth")
            else:
                best_model_fn = osp.join(self.cfg.MISC.OUT_DIR,
                                         f"{self.cfg.MODEL.NAME}-Fold-{self.nf}-best.pth")
            torch.save(self.model.state_dict(), best_model_fn)

    def parse_batch(self, batch):
        if len(batch) == 2:
            images, targets = batch
            meta_infos = None
        elif len(batch) == 3:
            images, targets, meta_infos = batch
        else:
            raise ValueError('parse batch error')
        images = images.to(self.device)
        targets = targets.to(self.device)
        if meta_infos is not None:
            meta_infos = meta_infos.to(self.device)
        return images, targets, meta_infos

    def print_stat(self, y_pred_tr, y_true_tr, y_pred_vl, y_true_vl):
        ps_tr = self.calc_stat(y_pred_tr, y_true_tr)
        ps_vl = self.calc_stat(y_pred_vl, y_true_vl)
        np.set_printoptions(precision=4)
        self.logger.info(f"Metrics Epoch: {self.train_epoch}")
        cm_train = ps_tr['cm']
        cm_valid = ps_vl['cm']
        if cm_train.shape[0] <= 10:  # only print when n_class <= 10
            self.logger.info('confusion matrix train\n')
            self.logger.info('{}\n'.format(cm_train))
            self.logger.info('confusion matrix valid\n')
            self.logger.info('{}\n'.format(cm_valid))
            self.logger.info("Num All Class Train: {}".format(np.sum(cm_train, axis=1)))
            self.logger.info("Acc All Class1 Train: {}".format(ps_tr['cls_acc1']))
            self.logger.info("Acc All Class2 Train: {}".format(ps_tr['cls_acc2']))
            self.logger.info("Acc All Class3 Train: {}".format(ps_tr['cls_acc3']))
            self.logger.info(
                f"Balance Acc 1 2 3 Train : {ps_tr['bal_acc1']:.4f} "
                f"{ps_tr['bal_acc2']:.4f} {ps_tr['bal_acc3']:.4f}")
        if cm_valid.shape[0] <= 10:  # only print when n_class <= 10
            self.logger.info("Num All Class Valid: {}".format(np.sum(cm_valid, axis=1)))
            self.logger.info("Acc All Class1 Valid: {}".format(ps_vl['cls_acc1']))
            self.logger.info("Acc All Class2 Valid: {}".format(ps_vl['cls_acc2']))
            self.logger.info("Acc All Class3 Valid: {}".format(ps_vl['cls_acc3']))
            self.logger.info(
                f"Balance Acc 1 2 3 Valid : {ps_vl['bal_acc1']:.4f} "
                f"{ps_vl['bal_acc2']:.4f} {ps_vl['bal_acc3']:.4f}")
            self.logger.info(f"Ave Acc : {ps_vl['avg_acc']:.4f}")
        return ps_tr['bal_acc1'], ps_vl['bal_acc1']

    def calc_stat(self, y_pred, y_true):
        y_true = np.array(y_true).astype('int64')
        y_pred = np.array(y_pred).astype('int64')
        cm = confusion_matrix(y_true, y_pred)
        # per-class recall, precision and IoU-style accuracy (eps avoids div by zero)
        cls_acc1 = cm.diagonal() / (0.0001 + np.sum(cm, axis=1))
        cls_acc2 = cm.diagonal() / (0.0001 + np.sum(cm, axis=0))
        cls_acc3 = cm.diagonal() / (0.0001 + np.sum(cm, axis=0) + np.sum(cm, axis=1) - cm.diagonal())
        avg_acc = np.sum(cm.diagonal()) / cm.sum()
        bal_acc1 = np.mean(cls_acc1)
        bal_acc2 = np.mean(cls_acc2)
        bal_acc3 = np.mean(cls_acc3)
        pred_stat = {'cm': cm, 'avg_acc': avg_acc,
                     'bal_acc1': bal_acc1, 'bal_acc2': bal_acc2, 'bal_acc3': bal_acc3,
                     'cls_acc1': cls_acc1, 'cls_acc2': cls_acc2, 'cls_acc3': cls_acc3}
        return pred_stat
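# Illustrative driver for BaseTrainer (a sketch, not code from this repo): the
# trainer exposes step() per batch, handle_new_batch() for periodic logging,
# and handle_new_epoch() for validation and checkpointing, so the outer loop
# is expected to look roughly like this:
#
#   trainer = BaseTrainer(cfg, model, train_dl, val_dl, criterion, optimizer,
#                         scheduler, start_epoch=1, nf=0)
#   for epoch in range(1, cfg.SOLVER.EPOCHS + 1):
#       for batch in trainer.train_dl:
#           trainer.step(batch)           # forward/backward + optimizer step
#           trainer.handle_new_batch()    # periodic logging
#       trainer.handle_new_epoch()        # validate, save best, reset meters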
def train(loader, net, criterion, optimizer, device, epoch, scheduler=None,
          net_type='resnet50_c3_locate'):
    net.train(True)
    optimizer.zero_grad()
    if net_type in ('resnet34_c3_locate', 'resnet50_c3_locate'):
        total_loss, x_loss, y_loss, w_loss, h_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(),
            AvgerageMeter(), AvgerageMeter())
    elif net_type == 'resnet34_c3_locate_centernet':
        total_loss, score_loss, xy_loss, wh_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter())
    for i, data in enumerate(loader):
        images, boxes = data
        images = images.to(device)
        if isinstance(boxes, dict):
            for key, value in boxes.items():
                boxes[key] = value.to(device)
        else:
            boxes = boxes.to(device)
        outputs = net(images)
        if net_type in ('resnet34_c3_locate', 'resnet50_c3_locate'):
            loss, loss_xx, loss_yy, loss_ww, loss_hh = criterion(boxes, outputs)
            total_loss.update(loss.item(), images.size(0))
            x_loss.update(loss_xx.item(), images.size(0))
            y_loss.update(loss_yy.item(), images.size(0))
            w_loss.update(loss_ww.item(), images.size(0))
            h_loss.update(loss_hh.item(), images.size(0))
        elif net_type == 'resnet34_c3_locate_centernet':
            loss, loss_score, loss_xy, loss_wh = criterion(boxes, outputs)
            total_loss.update(loss.item(), images.size(0))
            score_loss.update(loss_score.item(), images.size(0))
            xy_loss.update(loss_xy.item(), images.size(0))
            wh_loss.update(loss_wh.item(), images.size(0))
        loss.backward()
        #clip_grad_norm_(net.parameters(), 0.5)
        if scheduler is not None:
            scheduler.step()
        optimizer.step()
        optimizer.zero_grad()
    if net_type in ('resnet34_c3_locate', 'resnet50_c3_locate'):
        logger.info(f"Train Epoch: {epoch}, " +
                    f"Lr : {scheduler.get_lr()[1]:.5f}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss X: {x_loss.avg:.4f}, " +
                    f"Loss Y: {y_loss.avg:.4f}, " +
                    f"Loss W: {w_loss.avg:.4f}, " +
                    f"Loss H: {h_loss.avg:.4f} ")
    elif net_type == 'resnet34_c3_locate_centernet':
        logger.info(f"Train Epoch: {epoch}, " +
                    f"Lr : {scheduler.get_lr()[1]:.5f}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss Score: {score_loss.avg:.4f}, " +
                    f"Loss XY: {xy_loss.avg:.4f}, " +
                    f"Loss WH: {wh_loss.avg:.4f}, ")
    return total_loss.avg
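# AvgerageMeter is defined elsewhere in this repo; from its use throughout
# this file (update(val, n) plus .avg and .reset()) it is the usual
# running-average helper. A minimal sketch of that assumed interface, for
# reference only:
class _AvgerageMeterSketch:
    """Running weighted average; mirrors the assumed AvgerageMeter API."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0    # weighted sum of observed values
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # current running average

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count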
def test(loader, net, criterion, device, epoch, net_type='resnet50_c3_locate'):
    net.eval()
    if net_type in ('resnet34_c3_locate', 'resnet50_c3_locate'):
        total_loss, x_loss, y_loss, w_loss, h_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(),
            AvgerageMeter(), AvgerageMeter())
    elif net_type == 'resnet34_c3_locate_centernet':
        total_loss, score_loss, xy_loss, wh_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter())
    for _, data in enumerate(loader):
        images, boxes = data
        images = images.to(device)
        if isinstance(boxes, dict):
            for key, value in boxes.items():
                boxes[key] = value.to(device)
        else:
            boxes = boxes.to(device)
        with torch.no_grad():
            outputs = net(images)
            if net_type in ('resnet34_c3_locate', 'resnet50_c3_locate'):
                loss, loss_xx, loss_yy, loss_ww, loss_hh = criterion(boxes, outputs)
                total_loss.update(loss.item(), images.size(0))
                x_loss.update(loss_xx.item(), images.size(0))
                y_loss.update(loss_yy.item(), images.size(0))
                w_loss.update(loss_ww.item(), images.size(0))
                h_loss.update(loss_hh.item(), images.size(0))
            elif net_type == 'resnet34_c3_locate_centernet':
                loss, loss_score, loss_xy, loss_wh = criterion(boxes, outputs)
                total_loss.update(loss.item(), images.size(0))
                score_loss.update(loss_score.item(), images.size(0))
                xy_loss.update(loss_xy.item(), images.size(0))
                wh_loss.update(loss_wh.item(), images.size(0))
    if net_type in ('resnet34_c3_locate', 'resnet50_c3_locate'):
        logger.info(f"Test Epoch: {epoch}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss X: {x_loss.avg:.4f}, " +
                    f"Loss Y: {y_loss.avg:.4f}, " +
                    f"Loss W: {w_loss.avg:.4f}, " +
                    f"Loss H: {h_loss.avg:.4f} ")
    elif net_type == 'resnet34_c3_locate_centernet':
        logger.info(f"Test Epoch: {epoch}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss Score: {score_loss.avg:.4f}, " +
                    f"Loss XY: {xy_loss.avg:.4f}, " +
                    f"Loss WH: {wh_loss.avg:.4f}, ")
    return total_loss.avg
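# Illustrative usage of train()/test() for the locate nets (loader, net and
# criterion construction is handled elsewhere in the repo; the names below are
# placeholders, not functions from this file):
#
#   for epoch in range(1, num_epochs + 1):
#       train_loss = train(train_loader, net, criterion, optimizer, device,
#                          epoch, scheduler, net_type='resnet50_c3_locate')
#       val_loss = test(val_loader, net, criterion, device, epoch,
#                       net_type='resnet50_c3_locate')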
def test_tta(cfg, model, ds, criterion, nf):
    # TTA evaluation: the dataset is indexed directly, so each ds[idx] yields
    # one case (effectively batch size 1) with n_tta augmented crops.
    model.eval()
    device = get_device(cfg)
    logger = gl.get_value('logger')
    if cfg.DATASETS.K_FOLD == 1:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR, f"{cfg.MODEL.NAME}-best.pth")
    else:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR,
                                 f"{cfg.MODEL.NAME}-Fold-{nf}-best.pth")
    model.load_state_dict(torch.load(best_model_fn))
    n_tta = cfg.MISC.N_TTA
    n_class = cfg.DATASETS.NUM_CLASS
    n_case = 0.0
    y_true = list()
    y_pred = list()
    total_loss = AvgerageMeter()
    PREDS_ALL = []
    PREDS_ALL_TTA = []
    for idx in tqdm(range(len(ds))):
        with torch.no_grad():
            images, labels, meta_infos = parse_batch(ds[idx])
            y_true.append(labels.item())
            images = images.to(device)
            if meta_infos is not None:
                meta_infos = meta_infos.to(device)
                if meta_infos.dim() == 1:
                    meta_infos = meta_infos[None, ...]
                if images.dim() > 3 and meta_infos.size(0) == 1:
                    meta_infos = meta_infos.repeat(images.size(0), 1)
            labels = labels.to(device)
            labels = labels[None, ...]
            if images.dim() == 3:
                images = images[None, ...]
            if 'SingleView' in cfg.MODEL.NAME or 'SVBNN' in cfg.MODEL.NAME:
                outputs = model(images)
            elif model.mode == 'metasingleview':
                outputs = model(images, meta_infos)
            elif model.mode in ['sv_att', 'sv_db']:
                outputs = model(images, labels)
            if cfg.MISC.ONLY_TEST is False and cfg.DATASETS.NAMES == 'ISIC':
                loss = criterion(outputs, labels)
                total_loss.update(loss.item())
            #for LOSS_TYPE == 'pcs', pcsoftmax(outputs, weight=torch.tensor(cfg.DATASETS.LABEL_W), dim=1) was used here instead
            if isinstance(outputs, (list, tuple)):
                probs_0 = (0.5 * (F.softmax(outputs[0], dim=1)[0] +
                                  F.softmax(outputs[1], dim=1)[0])).cpu().numpy()
            else:
                if 'softmax' in cfg.MISC.TTA_MODE:
                    probs_0 = outputs.cpu().numpy()
                else:
                    probs_0 = F.softmax(outputs, dim=-1).cpu().numpy()
            # save the raw outputs of every TTA crop
            PREDS_ALL_TTA.append(outputs.cpu().numpy())
            if cfg.MISC.TTA_MODE in ['mean', 'mean_softmax']:
                pred_sum = np.mean(probs_0, axis=0)       # arithmetic mean over crops
            else:
                pred_sum = np.prod(probs_0, axis=0)       # geometric mean over crops
                pred_sum = np.power(pred_sum, 1.0 / n_tta)
            n_case += 1
            probs = np.round_(pred_sum, decimals=4)
            preds = np.argmax(pred_sum)
            y_pred.append(preds)
            if cfg.MISC.ONLY_TEST is False:
                PREDS_ALL.append([*probs, preds, int(labels.item())])
            else:
                PREDS_ALL.append([*probs, preds])
    PREDS_ALL = np.array(PREDS_ALL)
    PREDS_ALL_TTA = np.array(PREDS_ALL_TTA)
    np.set_printoptions(precision=4)
    if cfg.MISC.ONLY_TEST is False:
        pred_stat = calc_stat(y_pred, y_true)
        logger.info(f"Valid K-fold: {nf}")
        if n_class <= 10:
            logger.info('confusion matrix\n')
            cm = pred_stat['cm']
            logger.info('{}\n'.format(cm))
            logger.info("Num All Class: {}".format(np.sum(cm, axis=1)))
            logger.info("Acc All Class1: {}".format(pred_stat['cls_acc1']))
            logger.info("Acc All Class2: {}".format(pred_stat['cls_acc2']))
            logger.info("Acc All Class3: {}".format(pred_stat['cls_acc3']))
            logger.info(
                f"Balance Acc 1 2 3 : {pred_stat['bal_acc1']:.4f} "
                f"{pred_stat['bal_acc2']:.4f} {pred_stat['bal_acc3']:.4f}")
        logger.info(f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Average Acc: {pred_stat['avg_acc']}")
        return total_loss.avg, pred_stat['bal_acc1'], PREDS_ALL, PREDS_ALL_TTA
    else:
        return PREDS_ALL, PREDS_ALL_TTA
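# The two TTA aggregation modes used in test_tta above, on a toy example (for
# reference only; not code from this repo). With per-crop class probabilities
# p of shape (n_tta, n_class), e.g. two crops [[0.8, 0.2], [0.6, 0.4]]:
#   'mean' / 'mean_softmax': np.mean(p, axis=0)            -> [0.70, 0.30]
#   otherwise (geometric):   np.prod(p, axis=0) ** (1/2)   -> [0.6928, 0.2828]
# The geometric mean penalizes crops that disagree more strongly than the
# arithmetic mean does.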
def test_tta_heatmap(cfg, model, ds, criterion, nf):
    # Same TTA evaluation as test_tta, plus Grad-CAM heatmaps; only called
    # when cfg.MISC.CALC_HEATMAP is True. No torch.no_grad() here because the
    # heatmaps require gradients.
    (Path(cfg.MISC.OUT_DIR) / 'heatmap').mkdir(exist_ok=True)
    model.eval()
    device = get_device(cfg)
    logger = gl.get_value('logger')
    if cfg.DATASETS.K_FOLD == 1:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR, f"{cfg.MODEL.NAME}-best.pth")
    else:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR,
                                 f"{cfg.MODEL.NAME}-Fold-{nf}-best.pth")
    model.load_state_dict(torch.load(best_model_fn))
    n_tta = cfg.MISC.N_TTA
    n_class = cfg.DATASETS.NUM_CLASS
    n_case = 0.0
    y_true = list()
    y_pred = list()
    total_loss = AvgerageMeter()
    PREDS_ALL = []
    PREDS_ALL_TTA = []
    for idx in tqdm(range(len(ds))):
        fn = ds.flist[idx]
        img_ori = cv2.imread(fn)
        img_ori = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
        hh_ori, ww_ori, _ = img_ori.shape
        images, labels, meta_infos, aug_trans = parse_batch(ds[idx])
        y_true.append(labels.item())
        images = images.to(device)
        if meta_infos is not None:
            meta_infos = meta_infos.to(device)
            if meta_infos.dim() == 1:
                meta_infos = meta_infos[None, ...]
            if images.dim() > 3 and meta_infos.size(0) == 1:
                meta_infos = meta_infos.repeat(images.size(0), 1)
        labels = labels.to(device)
        labels = labels[None, ...]
        if images.dim() == 3:
            images = images[None, ...]
        if 'SingleView' in cfg.MODEL.NAME or 'SVBNN' in cfg.MODEL.NAME:
            outputs = model(images)
        elif model.mode == 'metasingleview':
            outputs = model(images, meta_infos)
        elif model.mode in ['sv_att', 'sv_db']:
            outputs = model(images, labels)
        if cfg.MISC.ONLY_TEST is False and cfg.DATASETS.NAMES == 'ISIC':
            loss = criterion(outputs, labels)
            total_loss.update(loss.item())
        if isinstance(outputs, (list, tuple)):
            probs_0 = 0.5 * (F.softmax(outputs[0], dim=1)[0] +
                             F.softmax(outputs[1], dim=1)[0]).cpu()
        else:
            if 'softmax' in cfg.MISC.TTA_MODE:
                probs_0 = outputs
            else:
                probs_0 = F.softmax(outputs, dim=-1)
        # save the raw outputs of every TTA crop
        PREDS_ALL_TTA.append(outputs.detach().cpu().numpy())
        probs = probs_0.detach().cpu().numpy()
        if cfg.MISC.TTA_MODE in ['mean', 'mean_softmax']:
            pred_sum = np.mean(probs, axis=0)        # arithmetic mean over crops
        else:
            pred_sum = np.prod(probs, axis=0)        # geometric mean over crops
            pred_sum = np.power(pred_sum, 1.0 / n_tta)
        n_case += 1
        probs = np.round_(pred_sum, decimals=4)
        preds = np.argmax(pred_sum)
        y_pred.append(preds)
        if cfg.MISC.ONLY_TEST is False:
            PREDS_ALL.append([*probs, preds, int(labels.item())])
        else:
            PREDS_ALL.append([*probs, preds])
        # Grad-CAM heatmap: backprop the predicted-class probability
        probs_0 = torch.mean(probs_0, dim=0)
        probs_0[preds].backward()
        gradients_IMG = model.get_activations_gradient_IMG()
        # pool the gradients across the channels
        pooled_gradients_IMG = torch.mean(gradients_IMG, dim=[0, 2, 3])
        # get the activations of the last conv layer
        activations_IMG = model.get_activations_IMG(images).detach()
        # weight the channels by the corresponding pooled gradients
        for i in range(pooled_gradients_IMG.shape[0]):
            activations_IMG[:, i, :, :] *= pooled_gradients_IMG[i]
        # average the channels of the activations
        heatmap_IMG = torch.mean(activations_IMG, dim=1).squeeze().cpu()
        # ReLU on top of the heatmap, then normalize to [0, 1]
        heatmap_IMG = F.relu(heatmap_IMG)
        heatmap_IMG /= torch.max(heatmap_IMG)
        hms = heatmap_IMG.cpu().numpy()
        img_w_hm = np.zeros((hh_ori, ww_ori), dtype='float32')
        img_n_hm = np.zeros((hh_ori, ww_ori), dtype='float32') + 0.00001
        # warp each crop's heatmap back into the original image frame and average
        for hm, trans in zip(hms, aug_trans):
            hm_imin = cv2.resize(hm, (images.shape[3], images.shape[2]))
            img_w_hm += cv2.warpAffine(hm_imin, trans, (ww_ori, hh_ori),
                                       flags=cv2.INTER_LINEAR,
                                       borderMode=cv2.BORDER_CONSTANT)
            img_n_hm += cv2.warpAffine(np.ones_like(hm_imin), trans,
                                       (ww_ori, hh_ori),
                                       flags=cv2.INTER_LINEAR,
                                       borderMode=cv2.BORDER_CONSTANT)
        hm_out = img_w_hm / img_n_hm
        hm_out = hm_out * ((hm_out > 0.25).astype('float32'))
        hm_out0 = np.uint8(255 * hm_out)
        hm_out = cv2.applyColorMap(hm_out0, cv2.COLORMAP_JET) * \
            np.uint8(hm_out0[..., None] > 0.25)
        # img_ori is RGB; [:, :, ::-1] converts back to BGR for cv2.imwrite
        superimposed_img_AP = hm_out * 0.4 + img_ori[:, :, ::-1]
        label_str = (Path(fn).stem + ' ' + cfg.DATASETS.DICT_LABEL[preds] +
                     ' prob = ' + str(probs[preds]))
        cv2.putText(superimposed_img_AP, label_str, (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
        fn_heatmap = Path(cfg.MISC.OUT_DIR) / 'heatmap' / (
            Path(fn).stem + '_' + cfg.DATASETS.DICT_LABEL[preds] + '.jpg')
        cv2.imwrite(str(fn_heatmap), superimposed_img_AP)
    PREDS_ALL = np.array(PREDS_ALL)
    PREDS_ALL_TTA = np.array(PREDS_ALL_TTA)
    np.set_printoptions(precision=4)
    if cfg.MISC.ONLY_TEST is False:
        pred_stat = calc_stat(y_pred, y_true)
        logger.info(f"Valid K-fold: {nf}")
        if n_class <= 10:
            logger.info('confusion matrix\n')
            cm = pred_stat['cm']
            logger.info('{}\n'.format(cm))
            logger.info("Num All Class: {}".format(np.sum(cm, axis=1)))
            logger.info("Acc All Class1: {}".format(pred_stat['cls_acc1']))
            logger.info("Acc All Class2: {}".format(pred_stat['cls_acc2']))
            logger.info("Acc All Class3: {}".format(pred_stat['cls_acc3']))
            logger.info(
                f"Balance Acc 1 2 3 : {pred_stat['bal_acc1']:.4f} "
                f"{pred_stat['bal_acc2']:.4f} {pred_stat['bal_acc3']:.4f}")
        logger.info(f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Average Acc: {pred_stat['avg_acc']}")
        return total_loss.avg, pred_stat['bal_acc1'], PREDS_ALL, PREDS_ALL_TTA
    else:
        return PREDS_ALL, PREDS_ALL_TTA
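# The heatmap block in test_tta_heatmap follows the standard Grad-CAM recipe:
# pool the target-class gradients per channel, weight the conv activations by
# those pooled gradients, average over channels, ReLU, normalize. A minimal
# self-contained sketch of that computation (the hook plumbing behind
# model.get_activations_* is this repo's own and is omitted here):
def _grad_cam_sketch(activations, gradients):
    """activations, gradients: (B, C, H, W) tensors from the target conv layer."""
    weights = gradients.mean(dim=[0, 2, 3])                          # (C,) pooled grads
    cam = (activations * weights[None, :, None, None]).mean(dim=1)  # (B, H, W)
    cam = F.relu(cam)                                                # keep positive evidence
    return cam / cam.max().clamp(min=1e-8)                           # normalize to [0, 1]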