Code Example #1
    def __init__(self, cfg, model, train_dl,  val_dl, criterion, optimizer,  scheduler,  start_epoch, nf):

        self.cfg = cfg
        self.logger = gl.get_value('logger')
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model

        self.train_epoch = start_epoch 
        self.batch_cnt = 0

        use_cuda = torch.cuda.is_available()
        if cfg.MODEL.DEVICE != 'cuda':
            use_cuda = False

        self.device = torch.device('cuda' if use_cuda else 'cpu')
      
        self.mgpu = use_cuda and torch.cuda.device_count() > 1

        self.criterion = criterion

        self.optimizer = optimizer

        self.scheduler = scheduler

        self.start_time = time.time()
        self.epoch_start_time = time.time()

        self.total_loss = AvgerageMeter()
        
        self.global_loss = AvgerageMeter()
        self.drop_loss = AvgerageMeter()
        self.crop_loss = AvgerageMeter()
        self.center_loss = AvgerageMeter()
        
        self.best_metric = 0.0

        self.nf  = nf
        self.logger.info('Trainer Start')

        self.y_true = list()
        self.y_pred = list()
        self.n_count = 0.0
        self.n_correct = 0.0
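
All of the examples rely on an AvgerageMeter helper (the spelling matches the identifier used throughout) whose definition is not included in this listing. The following is a minimal sketch of the assumed interface, inferred only from how it is called (update(val, n), .avg, .reset()):

class AvgerageMeter(object):
    """Running average of a scalar; interface inferred from usage in the trainer code."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the per-batch average, n the number of samples it was computed over
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count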
Code Example #2
    def evaluate(self, data_dl):
        self.model.eval()
        y_true_eval = list()
        y_pred_eval = list()
        #n_count_eval = 0.0
        #n_correct_eval = 0.0
        total_loss_eval = AvgerageMeter()
        loss_global_eval = AvgerageMeter()
        loss_local_eval = AvgerageMeter()
        
        with torch.no_grad():
            for batch in data_dl:
                images, targets,meta_infos = self.parse_batch(batch)

                if 'SingleView' in self.cfg.MODEL.NAME or 'SVBNN' in self.cfg.MODEL.NAME:
                    outputs = self.model(images)
                elif 'SVMeta' in self.cfg.MODEL.NAME:
                    outputs = self.model(images, meta_infos)
                elif 'SVAtt' in self.cfg.MODEL.NAME or 'SVDB' in self.cfg.MODEL.NAME or 'SVreid' in self.cfg.MODEL.NAME:
                    outputs = self.model(images, targets)
                else:
                    raise ValueError(f'unknown model type {self.cfg.MODEL.NAME}')
            
                
                if 'SVreid' in self.cfg.MODEL.NAME:
                    cls_score,_ = outputs
                    loss = self.criterion.crit(cls_score, targets)
        

                    
                    total_loss_eval.update(loss.item(), images.size(0))
                    _, preds = torch.max(outputs[0], 1)
                elif isinstance(outputs,(list,tuple)):
                    
                    # sum the two heads' losses; average their logits for the final prediction
                    loss1 = self.criterion.crit(outputs[0],targets) 
                    loss2 = self.criterion.crit(outputs[1],targets)
                    loss_global_eval.update(loss1.item(), images.size(0))
                    loss_local_eval.update(loss2.item(), images.size(0))
                    loss = loss1 + loss2
                    total_loss_eval.update(loss.item(), images.size(0))
                    
                    outputs = F.softmax(0.5*(outputs[0] +outputs[1]) , dim=-1)
                    _, preds = torch.max(outputs, 1)
                    

                    
                else:
                    #tensor
                    loss = self.criterion(outputs,targets)
                    
                    total_loss_eval.update(loss.item(), images.size(0))
                    _, preds = torch.max(outputs, 1)

                y_true_eval.extend(targets.cpu().numpy())
                y_pred_eval.extend(preds.cpu().numpy())
        

                #n_correct_eval += torch.sum((preds == targets).float()).item()
                #n_count_eval += images.size(0)
             
        metric = {'loss': total_loss_eval.avg,'y_true_eval':y_true_eval, 'y_pred_eval':y_pred_eval}
        if self.cfg.MODEL.LOSS_ATT is True:
            metric['loss_global'] = loss_global_eval.avg
            metric['loss_local'] = loss_local_eval.avg
            
        return metric
Code Example #3
class BaseTrainer(object):
    def __init__(self, cfg, model, train_dl,  val_dl, criterion, optimizer,  scheduler,  start_epoch, nf):

        self.cfg = cfg
        self.logger = gl.get_value('logger')
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model

        self.train_epoch = start_epoch 
        self.batch_cnt = 0

        use_cuda = torch.cuda.is_available()
        if cfg.MODEL.DEVICE != 'cuda':
            use_cuda = False

        self.device = torch.device('cuda' if use_cuda else 'cpu')
      
        self.mgpu = use_cuda and torch.cuda.device_count() > 1

        self.criterion = criterion

        self.optimizer = optimizer

        self.scheduler = scheduler

        self.start_time = time.time()
        self.epoch_start_time = time.time()

        self.total_loss = AvgerageMeter()
        
        self.global_loss = AvgerageMeter()
        self.drop_loss = AvgerageMeter()
        self.crop_loss = AvgerageMeter()
        self.center_loss = AvgerageMeter()
        
        self.best_metric = 0.0

        self.nf  = nf
        self.logger.info('Trainer Start')

        self.y_true = list()
        self.y_pred = list()
        self.n_count = 0.0
        self.n_correct = 0.0


    #@(Events.ITERATION_COMPLETED)
    def handle_new_batch(self):
        self.batch_cnt += 1
        if self.batch_cnt % self.cfg.MISC.LOG_PERIOD == 0 or self.batch_cnt == len(self.train_dl):
            elapsed = time.time() - self.start_time

            pstring = (
                "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | lr {:1.6f}| "
                "acc {:03.03%} | loss {:3.4f}".format(
                    self.train_epoch,
                    self.batch_cnt,
                    len(self.train_dl),
                    elapsed / self.cfg.MISC.LOG_PERIOD,
                    self.scheduler.get_lr()[0],
                    self.n_correct/self.n_count,
                    self.total_loss.avg)
            )
            self.logger.info(f"{pstring}")
            
            if self.cfg.MODEL.LOSS_ATT is True:
               pstring_loss =  ("global loss {:3.4f} | drop loss {:3.4f} | crop loss {:3.4f} | center loss {:3.4f}"  \
                                .format(self.global_loss.avg,self.drop_loss.avg,self.crop_loss.avg,self.center_loss.avg))
               self.logger.info(f"{pstring_loss}")
            if 'SVreid' in self.cfg.MODEL.NAME:
               pstring_loss =  ("global loss {:3.4f} |center loss {:3.4f}"  \
                                .format(self.global_loss.avg,self.center_loss.avg))
               self.logger.info(f"{pstring_loss}")
            self.start_time = time.time()

    


    def step(self, batch):
        self.model.train()
        
        if 'SVBNN' in self.cfg.MODEL.NAME :
            images, targets_a,meta_infos_a = self.parse_batch(batch[0])
            images_bal, targets_b,meta_infos_b = self.parse_batch(batch[1])
            targets = targets_b
            lam =   1.0-((self.train_epoch - 1) / (self.cfg.SOLVER.EPOCHS - 1)) ** 2  # parabolic decay
            
            #lam =   1.0-((self.train_epoch - 1) / (self.cfg.SOLVER.EPOCHS - 1))   # parabolic decay
            
        else:
            images, targets,meta_infos = self.parse_batch(batch)
        
            if self.cfg.MODEL.MIXUP is True:
                
                if self.cfg.MODEL.MIXUP_MODE == 'mixup':
                    mixfunc = mixup_data
                elif self.cfg.MODEL.MIXUP_MODE == 'cutmix':
                    mixfunc = cutmix_data
                else:
                    raise ValueError(f'unknown mixup mode {self.cfg.MODEL.MIXUP_MODE}')

                images, targets_a, targets_b, lam = mixfunc(
                    images, targets, alpha=self.cfg.MODEL.MIXUP_ALPHA, prob=self.cfg.MODEL.MIXUP_PROB)
            else:
                lam = 1.0

        if 'SingleView' in self.cfg.MODEL.NAME:
            outputs = self.model(images)
        elif 'SVMeta' in self.cfg.MODEL.NAME:
            outputs = self.model(images, meta_infos)
        elif 'SVAtt' in self.cfg.MODEL.NAME or 'SVDB' in self.cfg.MODEL.NAME or 'SVreid' in self.cfg.MODEL.NAME:
            outputs = self.model(images, targets, lam)
        elif 'SVBNN' in self.cfg.MODEL.NAME:
            outputs = self.model(x=images, x_rb=images_bal, alpha=lam)
        else:
            raise ValueError(f'unknown model type {self.cfg.MODEL.NAME}')

        if  self.cfg.MODEL.MIXUP is True or 'SVBNN' in self.cfg.MODEL.NAME :
            loss = mixup_criterion(self.criterion, outputs, targets_a, targets_b, lam)
            
        else:
            loss = self.criterion(outputs, targets)

        if isinstance(loss, (tuple, list)):
            self.total_loss.update(loss[0].item(), images.size(0))    
                    
            self.global_loss.update(loss[1]['global'].item(), images.size(0))   
            
            if 'crop' in loss[1].keys():
                self.crop_loss.update(loss[1]['crop'].item(), images.size(0))   
            if 'drop' in loss[1].keys():                
                self.drop_loss.update(loss[1]['drop'].item(), images.size(0))   
            self.center_loss.update(loss[1]['center'].item(), images.size(0))   
            
        else:

            self.total_loss.update(loss.item(), images.size(0))

        self.optimizer.zero_grad()
        if isinstance(outputs,(list,tuple)):
            _, preds = torch.max(outputs[0], 1)
            loss[0].backward()
        else:
            #tensor
            _, preds = torch.max(outputs, 1)
            loss.backward()

        self.y_pred.extend(preds.cpu().numpy())
        self.y_true.extend(targets.cpu().numpy())


        
        #clip_grad_norm_(net.parameters(), 0.5)
    
        self.optimizer.step()
        #self.optimizer.zero_grad()
            
        if self.scheduler is not None:
            self.scheduler.step()          
        
        self.n_correct += torch.sum((preds == targets).float()).item()
        self.n_count += images.size(0)
    
        return
        

    def evaluate(self, data_dl):
        self.model.eval()
        y_true_eval = list()
        y_pred_eval = list()
        #n_count_eval = 0.0
        #n_correct_eval = 0.0
        total_loss_eval = AvgerageMeter()
        loss_global_eval = AvgerageMeter()
        loss_local_eval = AvgerageMeter()
        
        with torch.no_grad():
            for batch in data_dl:
                images, targets,meta_infos = self.parse_batch(batch)

                if 'SingleView' in self.cfg.MODEL.NAME or 'SVBNN' in self.cfg.MODEL.NAME:
                    outputs = self.model(images)
                elif 'SVMeta' in self.cfg.MODEL.NAME:
                    outputs = self.model(images, meta_infos)
                elif 'SVAtt' in self.cfg.MODEL.NAME or 'SVDB' in self.cfg.MODEL.NAME or 'SVreid' in self.cfg.MODEL.NAME:
                    outputs = self.model(images, targets)
                else:
                    raise ValueError(f'unknown model type {self.cfg.MODEL.NAME}')
            
                
                if 'SVreid' in self.cfg.MODEL.NAME:
                    cls_score,_ = outputs
                    loss = self.criterion.crit(cls_score, targets)
        

                    
                    total_loss_eval.update(loss.item(), images.size(0))
                    _, preds = torch.max(outputs[0], 1)
                elif isinstance(outputs,(list,tuple)):
                    
                    # sum the two heads' losses; average their logits for the final prediction
                    loss1 = self.criterion.crit(outputs[0],targets) 
                    loss2 = self.criterion.crit(outputs[1],targets)
                    loss_global_eval.update(loss1.item(), images.size(0))
                    loss_local_eval.update(loss2.item(), images.size(0))
                    loss = loss1 + loss2
                    total_loss_eval.update(loss.item(), images.size(0))
                    
                    outputs = F.softmax(0.5*(outputs[0] +outputs[1]) , dim=-1)
                    _, preds = torch.max(outputs, 1)
                    

                    
                else:
                    #tensor
                    loss = self.criterion(outputs,targets)
                    
                    total_loss_eval.update(loss.item(), images.size(0))
                    _, preds = torch.max(outputs, 1)

                y_true_eval.extend(targets.cpu().numpy())
                y_pred_eval.extend(preds.cpu().numpy())
        

                #n_correct_eval += torch.sum((preds == targets).float()).item()
                #n_count_eval += images.size(0)
             
        metric = {'loss': total_loss_eval.avg,'y_true_eval':y_true_eval, 'y_pred_eval':y_pred_eval}
        if self.cfg.MODEL.LOSS_ATT is True:
            metric['loss_global'] = loss_global_eval.avg
            metric['loss_local'] = loss_local_eval.avg
            
        return metric

    #@(Events.EPOCH_COMPLETED)
    def handle_new_epoch(self):
        self.batch_cnt = 0
        
        if self.train_epoch % self.cfg.MISC.VALID_EPOCH  == 0:

            self.logger.info("-" * 96)
            val_metric = self.evaluate(self.val_dl)
        
            
            train_acc,valid_acc = self.print_stat(self.y_pred,self.y_true,val_metric['y_pred_eval'],val_metric['y_true_eval'])
            
            #use bal acc as metric instead of loss
            val_ms = float(valid_acc)
            val_loss = val_metric['loss']
            is_best = val_ms > self.best_metric
            
            
            if (is_best and  self.train_epoch>= self.cfg.MISC.SAV_EPOCH) or (self.train_epoch % self.cfg.MISC.SAV_EPOCH == 0):
                dict_sav = {'train_loss':self.total_loss.avg, 'valid_loss':val_loss , 'train_acc':train_acc, 'valid_acc':valid_acc}
                self.save(dict_sav,is_best)

            if is_best :
                self.best_metric = val_ms 
            

            
            self.logger.info("| end of epoch {:3d} | time: {:5.02f}s | val loss {:5.04f} |val ms {:5.03f} | best ms {:5.03f} |".format(
                self.train_epoch, (time.time() - self.epoch_start_time), val_loss,val_ms, self.best_metric)
            )
             
            if self.cfg.MODEL.LOSS_ATT is True:
                self.logger.info("val global {:2.04f}  | val local {:2.04f} | val total {:2.04f}".format(val_metric['loss_global'],val_metric['loss_local'],val_ms))
                
                   
                   
            self.logger.info('Train/val Acc: {:.4f}, {:.4f}'.format(train_acc,valid_acc))


            self.logger.info("-" * 96)

        self.epoch_start_time = time.time()
        self.train_epoch += 1
        self.total_loss.reset()
        self.global_loss.reset() 
        self.drop_loss.reset()
        self.crop_loss.reset() 
        self.center_loss.reset()
        
        
        self.y_true = list()
        self.y_pred = list()
        self.n_count = 0.0
        self.n_correct = 0.0

    def save(self, dict_sav, is_best=False):

        train_loss = dict_sav['train_loss']
        valid_loss = dict_sav['valid_loss'] 
        train_acc  = dict_sav['train_acc']
        valid_acc = dict_sav['valid_acc']

        model_path = osp.join(self.cfg.MISC.OUT_DIR, f"{self.cfg.MODEL.NAME}-Fold-{self.nf}-Epoch-{self.train_epoch}-trainloss-{train_loss:.4f}-loss-{valid_loss:.4f}-trainacc-{train_acc:.4f}-acc-{valid_acc:.4f}.pth")


        torch.save(self.model.state_dict(),model_path)


        if is_best:     

            if self.cfg.DATASETS.K_FOLD ==1:
                best_model_fn = osp.join(self.cfg.MISC.OUT_DIR, f"{self.cfg.MODEL.NAME}-best.pth")
            else:
                best_model_fn = osp.join(self.cfg.MISC.OUT_DIR, f"{self.cfg.MODEL.NAME}-Fold-{self.nf}-best.pth")
            torch.save(self.model.state_dict(), best_model_fn)

    def parse_batch(self,batch):
        if len(batch) == 2:
            images, targets = batch
            meta_infos = None
        elif len(batch) == 3:
            images, targets, meta_infos = batch
        else:
            raise ValueError('parse batch error')
        images = images.to(self.device)
        targets = targets.to(self.device)
        
        if meta_infos is not None:
            meta_infos = meta_infos.to(self.device)
            
            
        return images,targets,meta_infos

    def print_stat(self,y_pred_tr,y_true_tr,y_pred_vl,y_true_vl):
        ps_tr = self.calc_stat(y_pred_tr,y_true_tr)
        ps_vl = self.calc_stat(y_pred_vl,y_true_vl)

        np.set_printoptions(precision=4)

        self.logger.info(f"Metrics  Epoch: {self.train_epoch}")
        #logger.info(f"Balance Acc 1 2 3 : {bal_acc1:.4f} {bal_acc2:.4f} {bal_acc3:.4f}")


        cm_train = ps_tr['cm']
        cm_valid = ps_vl['cm']
        
        if cm_train.shape[0] <= 10:  # print per-class stats only when n_class <= 10
            self.logger.info('confusion matrix train\n')
            self.logger.info('{}\n'.format(cm_train))
            self.logger.info('confusion matrix valid\n')
            self.logger.info('{}\n'.format(cm_valid))
    
    
            self.logger.info("Num All Class Train: {}".format(np.sum(cm_train,axis = 1)))
            self.logger.info("Acc All Class1 Train: {}".format(ps_tr['cls_acc1']))
            self.logger.info("Acc All Class2 Train: {}".format(ps_tr['cls_acc2']))
            self.logger.info("Acc All Class3 Train: {}".format(ps_tr['cls_acc3']))
            
            
        self.logger.info(f"Balance Acc 1 2 3 Train : {ps_tr['bal_acc1']:.4f} {ps_tr['bal_acc2']:.4f} {ps_tr['bal_acc3']:.4f}")
        
        if cm_valid.shape[0] <= 10:  # print per-class stats only when n_class <= 10
            self.logger.info("Num All Class Valid: {}".format(np.sum(cm_valid,axis = 1)))
            self.logger.info("Acc All Class1 Valid: {}".format(ps_vl['cls_acc1']))
            self.logger.info("Acc All Class2 Valid: {}".format(ps_vl['cls_acc2']))
            self.logger.info("Acc All Class3 Valid: {}".format(ps_vl['cls_acc3']))
        self.logger.info(f"Balance Acc 1 2 3 Valid : {ps_vl['bal_acc1']:.4f} {ps_vl['bal_acc2']:.4f} {ps_vl['bal_acc3']:.4f}")
        self.logger.info(f"Ave Acc : {ps_vl['avg_acc']:.4f}")

        return ps_tr['bal_acc1'],ps_vl['bal_acc1']

    def calc_stat(self, y_pred,y_true):
        y_true = np.array(y_true).astype('int64') 
        y_pred = np.array(y_pred).astype('int64') 
        cm = confusion_matrix(y_true, y_pred)    
        # per-class recall (rows), precision (columns) and Jaccard index;
        # the small constant guards against division by zero for empty classes
        cls_acc1 = cm.diagonal() / (0.0001 + np.sum(cm, axis=1))
        cls_acc2 = cm.diagonal() / (0.0001 + np.sum(cm, axis=0))
        cls_acc3 = cm.diagonal() / (0.0001 + np.sum(cm, axis=0) + np.sum(cm, axis=1) - cm.diagonal())

        avg_acc = np.sum(cm.diagonal())/cm.sum()
        bal_acc1 = np.mean(cls_acc1)
        bal_acc2 = np.mean(cls_acc2)
        bal_acc3 = np.mean(cls_acc3)

        pred_stat = {'cm':cm, 'avg_acc':avg_acc, 'bal_acc1':bal_acc1, 'bal_acc2':bal_acc2,'bal_acc3':bal_acc3, 'cls_acc1':cls_acc1,'cls_acc2':cls_acc2,'cls_acc3':cls_acc3}
        return pred_stat
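
The listing does not include the outer loop that drives BaseTrainer. The sketch below shows how the step / handle_new_batch / handle_new_epoch methods are presumably wired together; the do_train name and the loop structure are assumptions, not part of the original code (note that for SVBNN models step expects the batch to be a pair of batches).

def do_train(cfg, model, train_dl, val_dl, criterion, optimizer, scheduler, start_epoch, nf):
    # hypothetical driver: wires the BaseTrainer callbacks together
    trainer = BaseTrainer(cfg, model, train_dl, val_dl, criterion,
                          optimizer, scheduler, start_epoch, nf)
    for _ in range(start_epoch, cfg.SOLVER.EPOCHS + 1):
        for batch in train_dl:
            trainer.step(batch)            # forward/backward + optimizer update
            trainer.handle_new_batch()     # periodic logging
        trainer.handle_new_epoch()         # validation, checkpointing, meter reset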
Code Example #4
def train(loader,
          net,
          criterion,
          optimizer,
          device,
          epoch,
          scheduler=None,
          net_type='resnet50_c3_locate'):
    net.train(True)
    optimizer.zero_grad()

    if net_type == 'resnet34_c3_locate' or net_type == 'resnet50_c3_locate':
        total_loss, x_loss, y_loss, w_loss, h_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter())
    elif net_type == 'resnet34_c3_locate_centernet':
        total_loss, score_loss, xy_loss, wh_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter())
    else:
        raise ValueError(f'unknown net_type {net_type}')

    for i, data in enumerate(loader):

        images, boxes = data
        #print(images.shape)
        images = images.to(device)

        if isinstance(boxes, dict):
            for key, value in boxes.items():
                boxes[key] = value.to(device)
        else:
            boxes = boxes.to(device)

        outputs = net(images)

        #_, preds = torch.max(outputs, 1)
        if net_type == 'resnet34_c3_locate' or net_type == 'resnet50_c3_locate':
            loss, loss_xx, loss_yy, loss_ww, loss_hh = criterion(
                boxes, outputs)

            total_loss.update(loss.item(), images.size(0))
            x_loss.update(loss_xx.item(), images.size(0))
            y_loss.update(loss_yy.item(), images.size(0))
            w_loss.update(loss_ww.item(), images.size(0))
            h_loss.update(loss_hh.item(), images.size(0))

        elif net_type == 'resnet34_c3_locate_centernet':
            loss, loss_score, loss_xy, loss_wh = criterion(boxes, outputs)

            total_loss.update(loss.item(), images.size(0))
            score_loss.update(loss_score.item(), images.size(0))
            xy_loss.update(loss_xy.item(), images.size(0))
            wh_loss.update(loss_wh.item(), images.size(0))

        loss.backward()
        #clip_grad_norm_(net.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

        # step the per-iteration LR scheduler after the optimizer update
        if scheduler is not None:
            scheduler.step()

    if net_type == 'resnet34_c3_locate' or net_type == 'resnet50_c3_locate':
        logger.info(f"Train  Epoch: {epoch}, " +
                    f"Lr    : {scheduler.get_lr()[1]:.5f}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss X: {x_loss.avg:.4f}, " +
                    f"Loss Y: {y_loss.avg:.4f}, " +
                    f"Loss W: {w_loss.avg:.4f}, " +
                    f"Loss H: {h_loss.avg:.4f} ")
    elif net_type == 'resnet34_c3_locate_centernet':
        logger.info(f"Train  Epoch: {epoch}, " +
                    f"Lr    : {scheduler.get_lr()[1]:.5f}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss Score: {score_loss.avg:.4f}, " +
                    f"Loss XY: {xy_loss.avg:.4f}, " +
                    f"Loss WH: {wh_loss.avg:.4f}, ")

    return total_loss.avg
Code Example #5
def test(loader, net, criterion, device, epoch, net_type='resnet50_c3_locate'):
    net.eval()

    if net_type == 'resnet34_c3_locate' or net_type == 'resnet50_c3_locate':
        total_loss, x_loss, y_loss, w_loss, h_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter())
    elif net_type == 'resnet34_c3_locate_centernet':
        total_loss, score_loss, xy_loss, wh_loss = (
            AvgerageMeter(), AvgerageMeter(), AvgerageMeter(), AvgerageMeter())
    else:
        raise ValueError(f'unknown net_type {net_type}')

    for _, data in enumerate(loader):
        images, boxes = data
        images = images.to(device)
        if isinstance(boxes, dict):
            for key, value in boxes.items():
                boxes[key] = value.to(device)
        else:
            boxes = boxes.to(device)

        with torch.no_grad():
            outputs = net(images)
            if net_type == 'resnet34_c3_locate' or net_type == 'resnet50_c3_locate':
                loss, loss_xx, loss_yy, loss_ww, loss_hh = criterion(
                    boxes, outputs)

                total_loss.update(loss.item(), images.size(0))
                x_loss.update(loss_xx.item(), images.size(0))
                y_loss.update(loss_yy.item(), images.size(0))
                w_loss.update(loss_ww.item(), images.size(0))
                h_loss.update(loss_hh.item(), images.size(0))

            elif net_type == 'resnet34_c3_locate_centernet':
                loss, loss_score, loss_xy, loss_wh = criterion(boxes, outputs)

                total_loss.update(loss.item(), images.size(0))
                score_loss.update(loss_score.item(), images.size(0))
                xy_loss.update(loss_xy.item(), images.size(0))
                wh_loss.update(loss_wh.item(), images.size(0))

            #pp = F.softmax(outputs,dim = 1)
            #logger.info("Conf: {}".format(pp[:,1].cpu().numpy()))
            #logger.info(f"labels: {labels}")

    if net_type == 'resnet34_c3_locate' or net_type == 'resnet50_c3_locate':
        logger.info(f"Test  Epoch: {epoch}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss X: {x_loss.avg:.4f}, " +
                    f"Loss Y: {y_loss.avg:.4f}, " +
                    f"Loss W: {w_loss.avg:.4f}, " +
                    f"Loss H: {h_loss.avg:.4f} ")
    elif net_type == 'resnet34_c3_locate_centernet':
        logger.info(f"Test  Epoch: {epoch}, " +
                    f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Loss Score: {score_loss.avg:.4f}, " +
                    f"Loss XY: {xy_loss.avg:.4f}, " +
                    f"Loss WH: {wh_loss.avg:.4f}, ")

    return total_loss.avg
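
A sketch of how the train and test functions above are presumably called; the run_locate_training wrapper and its arguments are illustrative placeholders, not part of the original code.

def run_locate_training(train_loader, valid_loader, net, criterion, optimizer,
                        scheduler, device, num_epochs):
    # hypothetical epoch loop; all objects are supplied by the caller
    best_loss = float('inf')
    for epoch in range(1, num_epochs + 1):
        train_loss = train(train_loader, net, criterion, optimizer, device,
                           epoch, scheduler=scheduler, net_type='resnet50_c3_locate')
        valid_loss = test(valid_loader, net, criterion, device, epoch,
                          net_type='resnet50_c3_locate')
        best_loss = min(best_loss, valid_loss)
    return best_loss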
Code Example #6
def test_tta(cfg, model, ds, criterion, nf):
    #epoch_loss,epoch_acc,pred_out  = test_tta(cfg, model, valid_loader,criterion,nf)

    #ds, net, criterion, device,epoch = -1,n_tta = 10,n_class = 4
    model.eval()

    device = get_device(cfg)
    logger = gl.get_value('logger')

    if cfg.DATASETS.K_FOLD == 1:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR,
                                 f"{cfg.MODEL.NAME}-best.pth")
    else:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR,
                                 f"{cfg.MODEL.NAME}-Fold-{nf}-best.pth")
    model.load_state_dict(torch.load(best_model_fn))

    n_tta = cfg.MISC.N_TTA
    n_class = cfg.DATASETS.NUM_CLASS

    # in tta, default batch size =1
    n_case = 0.0
    y_true = list()
    y_pred = list()
    total_loss = AvgerageMeter()

    PREDS_ALL = []
    PREDS_ALL_TTA = []
    for idx in tqdm(range(len(ds))):

        #print(images.shape)

        with torch.no_grad():
            #            if cfg.MISC.TTA_MODE in ['mean','mean_softmax']:
            #                pred_sum = torch.zeros((n_class),dtype = torch.float32)
            #            else:
            #                pred_sum = torch.ones((n_class),dtype = torch.float32)
            #
            #for n_t in range(n_tta):

            images, labels, meta_infos = parse_batch(ds[idx])

            y_true.append(labels.item())

            images = images.to(device)
            if meta_infos is not None:
                meta_infos = meta_infos.to(device)

                if meta_infos.dim() == 1:
                    meta_infos = meta_infos[None, ...]

                if images.dim() > 3 and meta_infos.size(0) == 1:

                    meta_infos = meta_infos.repeat(images.size(0), 1)

            labels = labels.to(device)
            labels = labels[None, ...]
            if images.dim() == 3:
                images = images[None, ...]

            if 'SingleView' in cfg.MODEL.NAME or 'SVBNN' in cfg.MODEL.NAME:
                outputs = model(images)
            elif model.mode == 'metasingleview':
                outputs = model(images, meta_infos)
            elif model.mode in ['sv_att', 'sv_db']:
                outputs = model(images, labels)
            else:
                raise ValueError(f'unknown model type {cfg.MODEL.NAME}')

            if cfg.MISC.ONLY_TEST is False and cfg.DATASETS.NAMES == 'ISIC':
                loss = criterion(outputs, labels)
                total_loss.update(loss.item())

            #if cfg.MODEL.LOSS_TYPE == 'pcs':
            #probs_0 = pcsoftmax(outputs,weight = torch.tensor(cfg.DATASETS.LABEL_W),dim=1)[0].cpu()
            #else:
            if isinstance(outputs, (list, tuple)):
                probs_0 = 0.5 * (F.softmax(outputs[0], dim=1)[0] +
                                 F.softmax(outputs[1], dim=1)[0]).cpu()
            else:

                if 'softmax' in cfg.MISC.TTA_MODE:
                    probs_0 = outputs.cpu().numpy()
                else:
                    probs_0 = F.softmax(outputs, dim=-1).cpu().numpy()

            #save outputs result
            #if cfg.MISC.ONLY_TEST is True:
            PREDS_ALL_TTA.append(outputs.cpu().numpy())

            if cfg.MISC.TTA_MODE in ['mean', 'mean_softmax']:
                pred_sum = np.mean(probs_0, axis=0)
            else:
                pred_sum = np.prod(probs_0, axis=0)
                pred_sum = np.power(pred_sum, 1.0 / n_tta)

            n_case += 1
            probs = np.round(pred_sum, decimals=4)

            preds = np.argmax(pred_sum)

            y_pred.append(preds)

            if cfg.MISC.ONLY_TEST is False:
                PREDS_ALL.append([*probs, preds, int(labels.item())])
            else:
                PREDS_ALL.append([*probs, preds])

    PREDS_ALL = np.array(PREDS_ALL)
    PREDS_ALL_TTA = np.array(PREDS_ALL_TTA)
    #avg_acc =   (PREDS_ALL[:,-2] == PREDS_ALL[:,-1]).sum()/n_case
    np.set_printoptions(precision=4)

    if cfg.MISC.ONLY_TEST is False:
        pred_stat = calc_stat(y_pred, y_true)
        logger.info(f"Valid  K-fold: {nf}")
        if n_class <= 10:
            logger.info('confusion matrix\n')
            cm = pred_stat['cm']
            logger.info('{}\n'.format(cm))
            logger.info("Num All Class: {}".format(np.sum(cm, axis=1)))
            logger.info("Acc All Class1: {}".format(pred_stat['cls_acc1']))
            logger.info("Acc All Class2: {}".format(pred_stat['cls_acc2']))
            logger.info("Acc All Class3: {}".format(pred_stat['cls_acc3']))

        logger.info(
            f"Balance Acc 1 2 3 : {pred_stat['bal_acc1']:.4f} {pred_stat['bal_acc2']:.4f} {pred_stat['bal_acc3']:.4f}"
        )

        logger.info(f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Average Acc:  {pred_stat['avg_acc']}")

        return total_loss.avg, pred_stat['bal_acc1'], PREDS_ALL, PREDS_ALL_TTA
    else:
        return PREDS_ALL, PREDS_ALL_TTA
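
test_tta reduces the per-augmentation class probabilities either with an arithmetic mean (TTA_MODE of 'mean' or 'mean_softmax') or with a geometric mean (product followed by the 1/n_tta power). A small self-contained illustration of the two reductions on dummy probabilities:

import numpy as np

# dummy softmax outputs for n_tta = 3 augmentations of one image, n_class = 4
probs = np.array([[0.70, 0.10, 0.10, 0.10],
                  [0.60, 0.20, 0.10, 0.10],
                  [0.50, 0.30, 0.10, 0.10]])
n_tta = probs.shape[0]

arith = np.mean(probs, axis=0)                         # TTA_MODE in ('mean', 'mean_softmax')
geom = np.power(np.prod(probs, axis=0), 1.0 / n_tta)   # otherwise: geometric mean

print(arith, np.argmax(arith))   # [0.6 0.2 0.1 0.1] -> class 0
print(geom, np.argmax(geom))     # geometric mean, also -> class 0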
Code Example #7
def test_tta_heatmap(cfg, model, ds, criterion, nf):
    #epoch_loss,epoch_acc,pred_out  = test_tta(cfg, model, valid_loader,criterion,nf)

    #ds, net, criterion, device,epoch = -1,n_tta = 10,n_class = 4

    # cfg.MISC.CALC_HEATMAP is True
    (Path(cfg.MISC.OUT_DIR) / 'heatmap').mkdir(exist_ok=True)

    model.eval()

    device = get_device(cfg)
    logger = gl.get_value('logger')

    if cfg.DATASETS.K_FOLD == 1:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR,
                                 f"{cfg.MODEL.NAME}-best.pth")
    else:
        best_model_fn = osp.join(cfg.MISC.OUT_DIR,
                                 f"{cfg.MODEL.NAME}-Fold-{nf}-best.pth")
    model.load_state_dict(torch.load(best_model_fn))

    n_tta = cfg.MISC.N_TTA
    n_class = cfg.DATASETS.NUM_CLASS

    # in tta, default batch size =1
    n_case = 0.0
    y_true = list()
    y_pred = list()
    total_loss = AvgerageMeter()

    PREDS_ALL = []
    PREDS_ALL_TTA = []
    for idx in tqdm(range(len(ds))):

        #print(images.shape)

        fn = ds.flist[idx]
        img_ori = cv2.imread(fn)
        img_ori = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
        hh_ori, ww_ori, _ = img_ori.shape

        images, labels, meta_infos, aug_trans = parse_batch(ds[idx])

        y_true.append(labels.item())

        images = images.to(device)
        if meta_infos is not None:
            meta_infos = meta_infos.to(device)

            if meta_infos.dim() == 1:
                meta_infos = meta_infos[None, ...]

            if images.dim() > 3 and meta_infos.size(0) == 1:

                meta_infos = meta_infos.repeat(images.size(0), 1)

        labels = labels.to(device)
        labels = labels[None, ...]
        if images.dim() == 3:
            images = images[None, ...]

        if 'SingleView' in cfg.MODEL.NAME or 'SVBNN' in cfg.MODEL.NAME:
            outputs = model(images)
        elif model.mode == 'metasingleview':
            outputs = model(images, meta_infos)
        elif model.mode in ['sv_att', 'sv_db']:
            outputs = model(images, labels)
        else:
            raise ValueError(f'unknown model type {cfg.MODEL.NAME}')

        if cfg.MISC.ONLY_TEST is False and cfg.DATASETS.NAMES == 'ISIC':
            loss = criterion(outputs, labels)
            total_loss.update(loss.item())

        #if cfg.MODEL.LOSS_TYPE == 'pcs':
        #probs_0 = pcsoftmax(outputs,weight = torch.tensor(cfg.DATASETS.LABEL_W),dim=1)[0].cpu()
        #else:
        if isinstance(outputs, (list, tuple)):
            probs_0 = 0.5 * (F.softmax(outputs[0], dim=1)[0] +
                             F.softmax(outputs[1], dim=1)[0]).cpu()
        else:

            if 'softmax' in cfg.MISC.TTA_MODE:
                probs_0 = outputs
            else:
                probs_0 = F.softmax(outputs, dim=-1)

        #save outputs result
        #if cfg.MISC.ONLY_TEST is True:
        PREDS_ALL_TTA.append(outputs.detach().cpu().numpy())

        probs = probs_0.detach().cpu().numpy()
        if cfg.MISC.TTA_MODE in ['mean', 'mean_softmax']:
            pred_sum = np.mean(probs, axis=0)
        else:
            pred_sum = np.prod(probs, axis=0)
            pred_sum = np.power(pred_sum, 1.0 / n_tta)

        n_case += 1
        probs = np.round(pred_sum, decimals=4)

        preds = np.argmax(pred_sum)

        y_pred.append(preds)
        if cfg.MISC.ONLY_TEST is False:
            PREDS_ALL.append([*probs, preds, int(labels.item())])
        else:
            PREDS_ALL.append([*probs, preds])

        # heatmap
        probs_0 = torch.mean(probs_0, dim=0)
        probs_0[preds].backward()

        gradients_IMG = model.get_activations_gradient_IMG()
        #gradients_META = model.get_activations_gradient_META()

        # pool the gradients over the batch and spatial dimensions (one weight per channel)
        pooled_gradients_IMG = torch.mean(gradients_IMG, dim=[0, 2, 3])
        #pooled_gradients_LAT = torch.mean(gradients_LAT, dim=[0, 2, 3])

        #pooled_gradients_AP = torch.mean(torch.abs(gradients_AP), dim=[0, 2, 3])
        #pooled_gradients_LAT = torch.mean(torch.abs(gradients_LAT), dim=[0, 2, 3])

        # get the activations of the last  layer
        activations_IMG = model.get_activations_IMG(images).detach()
        #activations_LAT  = model.get_activations_LAT(img).detach()

        # weight the channels by corresponding gradients
        for i in range(pooled_gradients_IMG.shape[0]):
            activations_IMG[:, i, :, :] *= pooled_gradients_IMG[i]
            #activations_LAT[:, i, :, :] *= pooled_gradients_LAT[i]

        # average the channels of the activations
        heatmap_IMG = torch.mean(activations_IMG, dim=1).squeeze().cpu()
        #heatmap_LAT = torch.mean(activations_LAT, dim=1).squeeze().cpu()

        # relu on top of the heatmap
        #heatmap_IMG = np.maximum(heatmap_IMG, 0)
        heatmap_IMG = F.relu(heatmap_IMG)
        #heatmap_LAT = np.maximum(heatmap_LAT, 0)

        # normalize the heatmap
        heatmap_IMG /= torch.max(heatmap_IMG)
        #heatmap_LAT /= torch.max(heatmap_LAT)

        #heatmap_AP *= (heatmap_AP>0.4).float()
        #heatmap_LAT *= (heatmap_LAT>0.4).float()

        hms = heatmap_IMG.cpu().numpy()
        img_w_hm = np.zeros((hh_ori, ww_ori), dtype='float32')
        img_n_hm = np.zeros((hh_ori, ww_ori), dtype='float32') + 0.00001

        # HM
        for hm, trans in zip(hms, aug_trans):
            hm_imin = cv2.resize(hm, (images.shape[3], images.shape[2]))
            img_w_hm += cv2.warpAffine(hm_imin,
                                       trans, (ww_ori, hh_ori),
                                       flags=cv2.INTER_LINEAR,
                                       borderMode=cv2.BORDER_CONSTANT)

            img_n_hm += cv2.warpAffine(np.ones_like(hm_imin),
                                       trans, (ww_ori, hh_ori),
                                       flags=cv2.INTER_LINEAR,
                                       borderMode=cv2.BORDER_CONSTANT)

        hm_out = img_w_hm / img_n_hm

        hm_out = hm_out * ((hm_out > 0.25).astype('float32'))
        hm_out0 = np.uint8(255 * hm_out)

        #hm_out = cv2.applyColorMap(hm_out, cv2.COLORMAP_JET)
        #superimposed_img_AP = hm_out * 0.4 + img_ori[:,:,::-1]
        hm_out = cv2.applyColorMap(
            hm_out0, cv2.COLORMAP_JET) * np.uint8(hm_out0[..., None] > 0.25)
        superimposed_img_AP = hm_out * 0.4 + img_ori[:, :, ::-1]
        #alpha = 0.5
        #superimposed_img_AP = cv2.addWeighted(img_ori, alpha, hm_out, 1 - alpha, 0)
        #superimposed_img_AP = superimposed_img_AP[:,:,::-1]

        label_str = Path(
            fn).stem + ' ' + cfg.DATASETS.DICT_LABEL[preds] + ' prob = ' + str(
                probs[preds])
        #cv2.rectangle(superimposed_img_AP, (0, 0), (200, 40), (0, 0, 0), -1)
        cv2.putText(superimposed_img_AP, label_str, (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)

        fn_heatmap = Path(cfg.MISC.OUT_DIR) / 'heatmap' / (
            Path(fn).stem + '_' + cfg.DATASETS.DICT_LABEL[preds] + '.jpg')
        cv2.imwrite(str(fn_heatmap), superimposed_img_AP)

    PREDS_ALL = np.array(PREDS_ALL)
    PREDS_ALL_TTA = np.array(PREDS_ALL_TTA)
    #avg_acc =   (PREDS_ALL[:,-2] == PREDS_ALL[:,-1]).sum()/n_case
    np.set_printoptions(precision=4)

    if cfg.MISC.ONLY_TEST is False:
        pred_stat = calc_stat(y_pred, y_true)
        logger.info(f"Valid  K-fold: {nf}")
        if n_class <= 10:
            logger.info('confusion matrix\n')
            cm = pred_stat['cm']
            logger.info('{}\n'.format(cm))
            logger.info("Num All Class: {}".format(np.sum(cm, axis=1)))
            logger.info("Acc All Class1: {}".format(pred_stat['cls_acc1']))
            logger.info("Acc All Class2: {}".format(pred_stat['cls_acc2']))
            logger.info("Acc All Class3: {}".format(pred_stat['cls_acc3']))

        logger.info(
            f"Balance Acc 1 2 3 : {pred_stat['bal_acc1']:.4f} {pred_stat['bal_acc2']:.4f} {pred_stat['bal_acc3']:.4f}"
        )

        logger.info(f"Average Loss: {total_loss.avg:.4f}, " +
                    f"Average Acc:  {pred_stat['avg_acc']}")

        return total_loss.avg, pred_stat['bal_acc1'], PREDS_ALL, PREDS_ALL_TTA
    else:
        return PREDS_ALL, PREDS_ALL_TTA
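
The heatmap part of test_tta_heatmap is a Grad-CAM style computation: pool the gradients of the predicted score over the batch and spatial dimensions, weight the feature-map channels by those pooled gradients, apply ReLU, and normalize. The helper below is a compact restatement of that step (the gradcam_heatmap name is hypothetical; it assumes the gradient and activation tensors returned by the model hooks have shape (N, C, H, W)):

import torch
import torch.nn.functional as F

def gradcam_heatmap(gradients, activations):
    """Grad-CAM style map from (N, C, H, W) gradients and activations."""
    # channel weights: average the gradients over batch and spatial dims
    weights = torch.mean(gradients, dim=[0, 2, 3])           # (C,)
    # weight each channel of the activations by its pooled gradient
    weighted = activations * weights[None, :, None, None]    # (N, C, H, W)
    # average over channels and keep only positive evidence
    heatmap = F.relu(torch.mean(weighted, dim=1))             # (N, H, W)
    # normalize to [0, 1], guarding against an all-zero map
    return heatmap / heatmap.max().clamp(min=1e-8)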