Example #1

import itertools

import numpy as np
import torch
import torch.nn as nn

# `Loss` and `AverageMeter` are helpers defined elsewhere in this project.
class PNetPruner(object):
    def __init__(self, epochs, dataloaders, model, optimizer, scheduler,
                 device, prune_ratio, finetune_epochs):
        self.epochs = epochs
        self.dataloaders = dataloaders
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)

        self.prune_iters = self._estimate_pruning_iterations(
            model, prune_ratio)
        print("Total prunning iterations:", self.prune_iters)
        self.finetune_epochs = finetune_epochs

    def compute_accuracy(self, prob_cls, gt_cls):
        # accuracy is computed only over samples whose label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)

        # get valid elements
        valid_gt_cls = gt_cls[mask]
        valid_prob_cls = prob_cls[mask]
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])

        # predicted class is the argmax over the class probabilities
        _, valid_pred_cls = torch.max(valid_prob_cls, dim=1)

        right_ones = torch.eq(valid_pred_cls.float(),
                              valid_gt_cls.float()).float()

        return torch.sum(right_ones) / float(size)

    def prune(self):
        print("Before Prunning...")
        self.train_epoch(0, 'val')
        for i in range(self.prune_iters):
            self.prune_step()
            print("After Prunning Iter ", i)
            self.train_epoch(i, 'val')
            print("Finetuning...")
            for epoch in range(self.finetune_epochs):
                self.train_epoch(i, 'train')
                self.train_epoch(i, 'val')
            torch.save(self.model.state_dict(),
                       './prunning/results/pruned_pnet.pth')
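            # export to ONNX with a fixed 1x3x12x12 dummy input (PNet's
            # 12x12 input patch size in the MTCNN cascade)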
            torch.onnx.export(self.model,
                              torch.randn(1, 3, 12, 12).to(self.device),
                              './onnx2ncnn/pruned_pnet.onnx',
                              input_names=['input'],
                              output_names=['scores', 'offsets'])

    def prune_step(self):
        self.model.train()

        # draw one random batch from the training set to rank filters on
        sample_idx = np.random.randint(0, len(self.dataloaders['train']))
        for batch_idx, sample in enumerate(self.dataloaders['train']):
            if batch_idx == sample_idx:
                data = sample['input_img']
                gt_cls = sample['cls_target']
                gt_bbox = sample['bbox_target']
                break

        data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to(
            self.device), gt_bbox.to(self.device).float()
        # clear stale gradients so the ranking uses only this batch
        self.model.zero_grad()
        pred_cls, pred_bbox = self.model(data)
        cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
        bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
        total_loss = cls_loss + 5 * bbox_loss
        total_loss.backward()
        self.model.prune(self.device)

    def train_epoch(self, epoch, phase):
        cls_loss_ = AverageMeter()
        bbox_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):
            data = sample['input_img']
            gt_cls = sample['cls_target']
            gt_bbox = sample['bbox_target']
            data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to(
                self.device), gt_bbox.to(self.device).float()

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                pred_cls, pred_bbox = self.model(data)

                # compute the cls loss and bbox loss and weight them together
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                total_loss = cls_loss + 5 * bbox_loss

                # compute classification accuracy
                accuracy = self.compute_accuracy(pred_cls, gt_cls)

                if phase == 'train':
                    total_loss.backward()
                    self.optimizer.step()

            # use .item() so the meters store floats, not graph-holding tensors
            cls_loss_.update(cls_loss.item(), data.size(0))
            bbox_loss_.update(bbox_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))

            #if batch_idx % 40 == 0:
            #    print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format(
            #        phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset),
            #        100. * batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr']))

        print(
            "{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}"
            .format(phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg,
                    accuracy_.avg))

        # torch.save(self.model.state_dict(), './pretrained_weights/quant_mtcnn/best_pnet.pth')

        return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg

    def _estimate_pruning_iterations(self, model, prune_ratio):
        '''Estimate how many feature maps to prune: divide the total number
        of parameters to prune by the estimated parameter count per feature
        map. Since we only prune one filter at a time, the number of
        iterations equals the number of filters to prune.

        Parameters:
        -----------
        model: pytorch model
        prune_ratio: ratio of total params to prune

        Return:
        -------
        number of pruning iterations
        '''
        # we only prune Conv2d layers here; Linear layers may be handled later
        conv2ds = [
            module for module in model.modules()
            if issubclass(type(module), nn.Conv2d)
        ]
        num_feature_maps = sum(conv2d.out_channels for conv2d in conv2ds)

        conv2d_params = (module.parameters() for module in model.modules()
                         if issubclass(type(module), nn.Conv2d))
        param_objs = itertools.chain(*conv2d_params)
        # total conv params; each feature map owns in * k_h * k_w weights (+ bias)
        num_params = sum(p.numel() for p in param_objs)

        params_per_map = num_params // num_feature_maps

        return int(np.ceil(num_params * prune_ratio / params_per_map))
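
As a quick sanity check of the estimate above, here is a minimal sketch with a toy two-layer convnet (the layer sizes are illustrative, not PNet's actual architecture); at a 50% prune ratio it reports 12 iterations:

import numpy as np
import torch.nn as nn

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 16, 3))
convs = [m for m in toy.modules() if isinstance(m, nn.Conv2d)]

num_maps = sum(c.out_channels for c in convs)      # 8 + 16 = 24
num_params = sum(p.numel() for c in convs
                 for p in c.parameters())          # (216 + 8) + (1152 + 16) = 1392
params_per_map = num_params // num_maps            # 1392 // 24 = 58
print(int(np.ceil(num_params * 0.5 / params_per_map)))  # ceil(696 / 58) = 12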
Example #2

class ONetTrainer(object):
    
    def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device):
        self.epochs = epochs
        self.dataloaders = dataloaders
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        
        # track the best validation loss for checkpointing
        self.best_val_loss = float('inf')

    def compute_accuracy(self, prob_cls, gt_cls):
        # labels of -2 are remapped to 1 (counted as positives); accuracy is
        # then computed only over samples whose label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        tmp_gt_cls = gt_cls.detach().clone()
        tmp_gt_cls[tmp_gt_cls == -2] = 1
        mask = torch.ge(tmp_gt_cls, 0)
        
        # get valid elements
        valid_gt_cls = tmp_gt_cls[mask]
        valid_prob_cls = prob_cls[mask]
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        
        # predicted class is the argmax over the class probabilities
        _, valid_pred_cls = torch.max(valid_prob_cls, dim=1)

        right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float()

        return torch.sum(right_ones) / float(size)
    
    def train(self):
        for epoch in range(self.epochs):
            self.train_epoch(epoch, 'train')
            self.train_epoch(epoch, 'val')

        
    def train_epoch(self, epoch, phase):
        cls_loss_ = AverageMeter()
        bbox_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        
        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):
            data = sample['input_img']
            gt_cls = sample['cls_target']
            gt_bbox = sample['bbox_target']
            gt_landmark = sample['landmark_target']
            
            data, gt_cls, gt_bbox, gt_landmark = data.to(self.device), \
                gt_cls.to(self.device), gt_bbox.to(self.device).float(), \
                gt_landmark.to(self.device).float()

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                pred_cls, pred_bbox, pred_landmark = self.model(data)
                # compute the cls loss and bbox loss and weight them together
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                landmark_loss = self.lossfn.landmark_loss(gt_cls, gt_landmark, pred_landmark)
                total_loss = cls_loss + 20*bbox_loss + 20*landmark_loss
                
                # compute classification accuracy
                accuracy = self.compute_accuracy(pred_cls, gt_cls)

                if phase == 'train':
                    total_loss.backward()
                    self.optimizer.step()

            # use .item() so the meters store floats, not graph-holding tensors
            cls_loss_.update(cls_loss.item(), data.size(0))
            bbox_loss_.update(bbox_loss.item(), data.size(0))
            landmark_loss_.update(landmark_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))
             
            if batch_idx % 40 == 0:
                print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f} landmark Loss: {:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format(
                    phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset),
                    100. * batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), landmark_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr']))
        
        if phase == 'train':
            self.scheduler.step()
        
        print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} landmark Loss: {:.6f} Accuracy: {:.6f}".format(
            phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, landmark_loss_.avg, accuracy_.avg))
        
        if phase == 'val' and total_loss_.avg < self.best_val_loss:
            self.best_val_loss = total_loss_.avg
            torch.save(self.model.state_dict(), './pretrained_weights/mtcnn/best_onet_landmark_2.pth')
        
        return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, landmark_loss_.avg, accuracy_.avg
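
The -2 remap in compute_accuracy above scores those samples as positives, while other labels below zero stay masked out. A minimal sketch of the masking on a hand-made batch (the label convention is taken from the code above; the probabilities are made up):

import torch

# 1 = positive, 0 = negative, -1 = masked out, -2 = remapped to positive
gt_cls = torch.tensor([1, 0, -1, -2])
prob_cls = torch.tensor([[0.2, 0.8],   # argmax 1: correct
                         [0.9, 0.1],   # argmax 0: correct
                         [0.5, 0.5],   # excluded by the mask
                         [0.3, 0.7]])  # argmax 1: correct after remap

tmp = gt_cls.clone()
tmp[tmp == -2] = 1
mask = torch.ge(tmp, 0)
_, pred = torch.max(prob_cls[mask], dim=1)
print(torch.eq(pred, tmp[mask]).float().mean())  # tensor(1.)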
Example #3

class PNetTrainer(object):
    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        # accuracy is computed only over samples whose label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # get valid element
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()

        return torch.sum(right_ones) / float(size)
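    # NOTE: unlike the argmax-based compute_accuracy in the examples above,
    # this variant thresholds the positive-class probability at 0.6.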

    def update_lr(self, epoch):
        """
        Reset the learning rate of the model optimizer to the base value.
        :param epoch: current training epoch (currently unused)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            data, gt_label, gt_bbox = data.to(self.device), gt_label.to(
                self.device), gt_bbox.to(self.device).float()

            cls_pred, box_offset_pred = self.model(data)
            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox,
                                                   box_offset_pred)

            total_loss = cls_loss + box_offset_loss * 0.5
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # use .item() so the meters store floats, not graph-holding tensors
            cls_loss_.update(cls_loss.item(), data.size(0))
            box_offset_loss_.update(box_offset_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))

            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader),
                        total_loss.item(), accuracy.item()))

        # step the LR scheduler once per epoch, after the optimizer updates
        self.scheduler.step()

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler.get_last_lr()[0]

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, total_loss_.avg, accuracy_.avg
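
All three classes above rely on an AverageMeter helper to keep per-epoch running averages weighted by batch size; it is not shown on this page. A minimal sketch of such a meter (an assumption about the interface used above, not necessarily this repository's exact implementation):

class AverageMeter(object):
    """Keep a running, sample-weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of values
        self.count = 0    # number of samples seen
        self.avg = 0.0    # running per-sample average

    def update(self, val, n=1):
        # `n` is the batch size, so `avg` is weighted per sample
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count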