import itertools

import numpy as np
import torch
import torch.nn as nn

# Loss and AverageMeter are this project's own loss/metric utilities,
# imported from elsewhere in the repo.


class PNetPruner(object):
    def __init__(self, epochs, dataloaders, model, optimizer, scheduler,
                 device, prune_ratio, finetune_epochs):
        self.epochs = epochs
        self.dataloaders = dataloaders
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        self.prune_iters = self._estimate_pruning_iterations(model, prune_ratio)
        print("Total pruning iterations:", self.prune_iters)
        self.finetune_epochs = finetune_epochs

    def compute_accuracy(self, prob_cls, gt_cls):
        # only samples with label >= 0 count towards detection accuracy
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # keep only the valid elements
        valid_gt_cls = gt_cls[mask]
        valid_prob_cls = prob_cls[mask]
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        # take the argmax over the softmax scores as the predicted class
        _, valid_pred_cls = torch.max(valid_prob_cls, dim=1)
        right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def prune(self):
        print("Before pruning...")
        self.train_epoch(0, 'val')
        for i in range(self.prune_iters):
            self.prune_step()
            print("After pruning iter", i)
            self.train_epoch(i, 'val')
        print("Finetuning...")
        for epoch in range(self.finetune_epochs):
            self.train_epoch(epoch, 'train')
            self.train_epoch(epoch, 'val')
        torch.save(self.model.state_dict(), './prunning/results/pruned_pnet.pth')
        torch.onnx.export(self.model,
                          torch.randn(1, 3, 12, 12).to(self.device),
                          './onnx2ncnn/pruned_pnet.onnx',
                          input_names=['input'],
                          output_names=['scores', 'offsets'])

    def prune_step(self):
        self.model.train()
        # pick one random batch to estimate filter importance from its gradients
        sample_idx = np.random.randint(0, len(self.dataloaders['train']))
        for batch_idx, sample in enumerate(self.dataloaders['train']):
            if batch_idx == sample_idx:
                data = sample['input_img']
                gt_cls = sample['cls_target']
                gt_bbox = sample['bbox_target']
                data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to(
                    self.device), gt_bbox.to(self.device).float()
                # clear gradients left over from earlier prune steps
                self.model.zero_grad()
                pred_cls, pred_bbox = self.model(data)
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                total_loss = cls_loss + 5 * bbox_loss
                total_loss.backward()
                # rank filters by the accumulated gradients and remove one
                self.model.prune(self.device)
                break  # only one batch is needed per pruning step

    def train_epoch(self, epoch, phase):
        cls_loss_ = AverageMeter()
        bbox_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):
            data = sample['input_img']
            gt_cls = sample['cls_target']
            gt_bbox = sample['bbox_target']
            data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to(
                self.device), gt_bbox.to(self.device).float()

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                pred_cls, pred_bbox = self.model(data)
                # compute the cls loss and bbox loss and weight them together
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                total_loss = cls_loss + 5 * bbox_loss
                # compute classification accuracy
                accuracy = self.compute_accuracy(pred_cls, gt_cls)
                if phase == 'train':
                    total_loss.backward()
                    self.optimizer.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            bbox_loss_.update(bbox_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))

            # if batch_idx % 40 == 0:
            #     print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss: {:.6f}\tAccuracy: {:.6f} LR: {:.7f}'.format(
            #         phase, epoch, batch_idx * len(data),
            #         len(self.dataloaders[phase].dataset),
            #         100. * batch_idx / len(self.dataloaders[phase]),
            #         total_loss.item(), cls_loss.item(), bbox_loss.item(),
            #         accuracy.item(), self.optimizer.param_groups[0]['lr']))

        print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}"
              .format(phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg))
        # torch.save(self.model.state_dict(), './pretrained_weights/quant_mtcnn/best_pnet.pth')
        return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg

    def _estimate_pruning_iterations(self, model, prune_ratio):
        '''Estimate the number of pruning iterations.

        Divides the total number of parameters to prune by the average number
        of parameters per feature map. Since only one filter is removed per
        iteration, the result equals the total number of filters to prune.

        Parameters:
        -----------
        model: pytorch model
        prune_ratio: ratio of total conv parameters to prune

        Return:
        -------
        number of pruning iterations
        '''
        # only Conv2d layers are pruned here; Linear layers may be considered later
        conv2ds = [
            module for module in model.modules()
            if issubclass(type(module), nn.Conv2d)
        ]
        num_feature_maps = sum(conv2d.out_channels for conv2d in conv2ds)
        conv2d_params = (module.parameters() for module in model.modules()
                         if issubclass(type(module), nn.Conv2d))
        param_objs = itertools.chain(*conv2d_params)
        # each feature map owns roughly in_channels * kernel_w * kernel_h weights
        num_params = sum(int(np.prod(p.size())) for p in param_objs)
        params_per_map = num_params // num_feature_maps
        return int(np.ceil(num_params * prune_ratio / params_per_map))
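
# All classes in this file depend on an AverageMeter utility (and a Loss
# module) defined elsewhere in the repo. For reference, a minimal meter that
# is compatible with the `update(value, n)` / `.avg` usage above could look
# like the sketch below -- this is an assumption about the helper's behavior,
# not the repo's actual implementation.
class AverageMeterSketch(object):
    """Tracks a running (sample-weighted) average of a scalar metric."""

    def __init__(self):
        self.sum = 0.0   # weighted sum of all observed values
        self.count = 0   # total weight (e.g. number of samples seen)
        self.avg = 0.0   # current running average

    def update(self, value, n=1):
        # `value` is a per-batch average; weight it by the batch size `n`
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
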
class ONetTrainer(object):
    def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device):
        self.epochs = epochs
        self.dataloaders = dataloaders
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        # track the best validation loss so far to checkpoint the best model
        self.best_val_loss = 100

    def compute_accuracy(self, prob_cls, gt_cls):
        # only samples with label >= 0 count towards detection accuracy
        prob_cls = torch.squeeze(prob_cls)
        # landmark samples (label -2) are faces, so count them as positives
        tmp_gt_cls = gt_cls.detach().clone()
        tmp_gt_cls[tmp_gt_cls == -2] = 1
        mask = torch.ge(tmp_gt_cls, 0)
        # keep only the valid elements
        valid_gt_cls = tmp_gt_cls[mask]
        valid_prob_cls = prob_cls[mask]
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        # take the argmax over the softmax scores as the predicted class
        _, valid_pred_cls = torch.max(valid_prob_cls, dim=1)
        right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def train(self):
        for epoch in range(self.epochs):
            self.train_epoch(epoch, 'train')
            self.train_epoch(epoch, 'val')

    def train_epoch(self, epoch, phase):
        cls_loss_ = AverageMeter()
        bbox_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):
            data = sample['input_img']
            gt_cls = sample['cls_target']
            gt_bbox = sample['bbox_target']
            gt_landmark = sample['landmark_target']
            data, gt_cls, gt_bbox, gt_landmark = data.to(self.device), \
                gt_cls.to(self.device), gt_bbox.to(self.device).float(), \
                gt_landmark.to(self.device).float()

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                pred_cls, pred_bbox, pred_landmark = self.model(data)
                # compute the cls, bbox, and landmark losses and weight them together
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                landmark_loss = self.lossfn.landmark_loss(gt_cls, gt_landmark, pred_landmark)
                total_loss = cls_loss + 20 * bbox_loss + 20 * landmark_loss
                # compute classification accuracy
                accuracy = self.compute_accuracy(pred_cls, gt_cls)
                if phase == 'train':
                    total_loss.backward()
                    self.optimizer.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            bbox_loss_.update(bbox_loss.item(), data.size(0))
            landmark_loss_.update(landmark_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))

            if batch_idx % 40 == 0:
                print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss: {:.6f} landmark Loss: {:.6f}\tAccuracy: {:.6f} LR: {:.7f}'.format(
                    phase, epoch, batch_idx * len(data),
                    len(self.dataloaders[phase].dataset),
                    100. * batch_idx / len(self.dataloaders[phase]),
                    total_loss.item(), cls_loss.item(), bbox_loss.item(),
                    landmark_loss.item(), accuracy.item(),
                    self.optimizer.param_groups[0]['lr']))

        if phase == 'train':
            self.scheduler.step()

        print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} landmark Loss: {:.6f} Accuracy: {:.6f}".format(
            phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg,
            landmark_loss_.avg, accuracy_.avg))

        if phase == 'val' and total_loss_.avg < self.best_val_loss:
            self.best_val_loss = total_loss_.avg
            torch.save(self.model.state_dict(),
                       './pretrained_weights/mtcnn/best_onet_landmark_2.pth')

        return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, landmark_loss_.avg, accuracy_.avg
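
# A hedged usage sketch for ONetTrainer. `ONet` and the dataloader
# construction are assumed to come from this repo; the optimizer and
# scheduler settings below are illustrative only, not the project's actual
# training script.
#
# model = ONet().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# trainer = ONetTrainer(epochs=50,
#                       dataloaders={'train': train_loader, 'val': val_loader},
#                       model=model, optimizer=optimizer,
#                       scheduler=scheduler, device=device)
# trainer.train()  # checkpoints the best model by validation loss
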
class PNetTrainer(object):
    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        # only samples with label >= 0 count towards detection accuracy
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # keep only the valid elements
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        # threshold the face probability at 0.6 to get the predicted class
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def update_lr(self, epoch):
        """Reset the learning rate of the model optimizer to self.lr.

        :param epoch: current training epoch (unused; kept for interface
                      compatibility)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            data, gt_label, gt_bbox = data.to(self.device), gt_label.to(
                self.device), gt_bbox.to(self.device).float()

            cls_pred, box_offset_pred = self.model(data)
            # compute the cls loss and box offset loss and weight them together
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox, box_offset_pred)
            total_loss = cls_loss + box_offset_loss * 0.5
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            box_offset_loss_.update(box_offset_loss.item(), data.size(0))
            total_loss_.update(total_loss.item(), data.size(0))
            accuracy_.update(accuracy.item(), data.size(0))

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(
                epoch, batch_idx * len(data), len(self.train_loader.dataset),
                100. * batch_idx / len(self.train_loader),
                total_loss.item(), accuracy.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler.get_lr()[0]

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, total_loss_.avg, accuracy_.avg
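
# A hedged usage sketch for PNetTrainer. `PNet` stands in for the repo's own
# 12x12 proposal network and `train_loader` for its dataset pipeline; the
# hyperparameters are illustrative assumptions, not the project's actual
# configuration.
#
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = PNet().to(device)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# trainer = PNetTrainer(lr=0.01, train_loader=train_loader, model=model,
#                       optimizer=optimizer, scheduler=scheduler,
#                       logger=None, device=device)
# for epoch in range(1, num_epochs + 1):
#     trainer.train(epoch)
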