def train_epoch(self, epoch, phase):
        """Run one pass over ``self.dataloaders[phase]`` and return epoch averages.

        In the ``'train'`` phase the model is put in ``train()`` mode, gradients
        are enabled, and the optimizer is stepped after every batch. The LR
        scheduler is stepped once at the end of the epoch. Any other phase puts
        the model in ``eval()`` mode with gradients disabled.

        In the ``'val'`` phase, when the epoch's average total loss is below
        ``self.best_val_loss``, that value is recorded. The model is then saved
        both as a state dict (``./ckpts/best_rnet.pth``) and as a whole pickled
        module (``./ckpts/best_rnet``).

        Args:
            epoch: Epoch index; used only in the log output.
            phase: Key into ``self.dataloaders``; ``'train'`` enables training.

        Returns:
            ``(cls_loss_avg, bbox_loss_avg, total_loss_avg, accuracy_avg)`` as
            accumulated by ``AverageMeter`` with the batch size as weight
            (presumably a sample-weighted mean -- confirm against AverageMeter).
        """
        cls_loss_ = AverageMeter()
        bbox_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        is_train = phase == 'train'
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):
            data = sample['input_img'].to(self.device)
            gt_cls = sample['cls_target'].to(self.device)
            gt_bbox = sample['bbox_target'].to(self.device).float()
            batch_size = data.size(0)

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(is_train):
                pred_cls, pred_bbox = self.model(data)

                # Classification loss plus the bbox regression loss weighted 5x.
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                total_loss = cls_loss + 5 * bbox_loss

                # Classification accuracy for this batch.
                accuracy = self.compute_accuracy(pred_cls, gt_cls)

                if is_train:
                    total_loss.backward()
                    self.optimizer.step()

            # Detach to plain Python numbers before accumulating. Handing the
            # tensors themselves to the meters keeps every batch's autograd
            # graph referenced for the rest of the epoch (memory grows each
            # step) and turns the epoch averages into tensors.
            cls_val = cls_loss.item()
            bbox_val = bbox_loss.item()
            total_val = total_loss.item()
            acc_val = accuracy.item()

            cls_loss_.update(cls_val, batch_size)
            bbox_loss_.update(bbox_val, batch_size)
            total_loss_.update(total_val, batch_size)
            accuracy_.update(acc_val, batch_size)

            if batch_idx % 40 == 0:
                print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format(
                    phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset),
                    100. * batch_idx / len(self.dataloaders[phase]), total_val, cls_val, bbox_val, acc_val, self.optimizer.param_groups[0]['lr']))

        if is_train:
            self.scheduler.step()

        print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}".format(
            phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg))

        # Keep a checkpoint of the best model seen so far on the validation set.
        if phase == 'val' and total_loss_.avg < self.best_val_loss:
            self.best_val_loss = total_loss_.avg
            torch.save(self.model.state_dict(), './ckpts/best_rnet.pth')
            torch.save(self.model, './ckpts/best_rnet')

        return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg
示例#2
0
def measure_cams(config: Config):
    """Score saved CAM outputs for the 'train' split and log metrics to wandb.

    The 'train' split of ``config.eval_dataset_root`` is iterated in order.
    For each image a payload is built that holds:

    - the image and label paths from the batch metadata;
    - the CAM-derived prediction at ``<artifact dir>/labels_cam/<name>.png``;
    - the CAM at ``<artifact dir>/cam/<name>.png``.

    Payloads are scored by ``_measure_sample`` in a pool of 8 processes. Each
    result is expected to contain 'accuracy', 'mapr', 'miou' and 'count'.
    Per-sample logs for the first 8 samples are sent to wandb at
    ``step=count``. After every batch, the running averages of the three
    metrics are logged.

    Args:
        config: Run configuration. It supplies the dataset root, image size,
            batch size, sweep id and classifier name used here.
    """
    config_json = config.toDictionary()
    print('measure_cams')
    print(config_json)
    import os

    from torch.utils.data.dataloader import DataLoader
    from data.loader_segmentation import Segmentation
    from artifacts.artifact_manager import artifact_manager
    from multiprocessing import Pool

    # Set up data loader
    dataloader = DataLoader(Segmentation(
        config.eval_dataset_root,
        source='train',
        augmentation='val',
        image_size=config.classifier_image_size,
        requested_labels=['classification', 'segmentation']),
                            batch_size=config.cams_measure_batch_size,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=2,
                            prefetch_factor=2)

    # Get cams directory
    cam_root_path = os.path.join(artifact_manager.getDir(), 'cam')
    label_cam_path = os.path.join(artifact_manager.getDir(), 'labels_cam')

    count = 0

    wandb.init(entity='kobus_wits',
               project='wass_measure_cams',
               name=config.sweep_id + '_cam_' + config.classifier_name,
               config=config_json)
    avg_meter = AverageMeter('accuracy', 'mapr', 'miou')

    # Create the worker pool once for the whole run. Previously a fresh
    # 8-process pool was spawned and torn down for every batch.
    with Pool(8) as pool:
        for batch in dataloader:
            # NOTE(review): index 2 is assumed to be the per-image metadata dict
            # produced by Segmentation -- confirm against the dataset.
            datapacket = batch[2]

            payloads = []
            for image_no, image_name in enumerate(datapacket['image_name']):
                payload = {
                    'count': count,
                    'image_path': datapacket['image_path'][image_no],
                    'label_path': datapacket['label_path'][image_no],
                    'predi_path': os.path.join(label_cam_path,
                                               image_name + '.png'),
                    'cam_path': os.path.join(cam_root_path, image_name + '.png'),
                }
                payloads.append(payload)
                count += 1
                print('Measure cam : ', count, end='\r')

            logs = pool.map(_measure_sample, payloads)

            for log in logs:
                avg_meter.add({
                    'accuracy': log['accuracy'],
                    'mapr': log['mapr'],
                    'miou': log['miou'],
                })

                # Only the first few samples get individual per-step logs.
                if log['count'] < 8:
                    wandb.log(log, step=log['count'])

            # Running averages over every sample seen so far.
            wandb.log({
                'accuracy': avg_meter.get('accuracy'),
                'mapr': avg_meter.get('mapr'),
                'miou': avg_meter.get('miou'),
            })

    wandb.finish()