class ONetTrainer(object):
    def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device):
        self.epochs = epochs
        self.dataloaders = dataloaders
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        # best validation loss seen so far, used to decide when to save the model
        self.best_val_loss = 100

    def compute_accuracy(self, prob_cls, gt_cls):
        # accuracy is computed only on samples whose label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        tmp_gt_cls = gt_cls.detach().clone()
        # landmark samples (label -2) are counted as positives for accuracy
        tmp_gt_cls[tmp_gt_cls == -2] = 1
        mask = torch.ge(tmp_gt_cls, 0)
        # keep only the valid elements
        valid_gt_cls = tmp_gt_cls[mask]
        valid_prob_cls = prob_cls[mask]
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        # predicted class is the argmax over the softmax outputs
        _, valid_pred_cls = torch.max(valid_prob_cls, dim=1)
        right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float()
        return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size))

    def train(self):
        for epoch in range(self.epochs):
            self.train_epoch(epoch, 'train')
            self.train_epoch(epoch, 'val')

    def train_epoch(self, epoch, phase):
        cls_loss_ = AverageMeter()
        bbox_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):
            data = sample['input_img']
            gt_cls = sample['cls_target']
            gt_bbox = sample['bbox_target']
            gt_landmark = sample['landmark_target']

            data, gt_cls, gt_bbox, gt_landmark = data.to(self.device), \
                gt_cls.to(self.device), gt_bbox.to(self.device).float(), \
                gt_landmark.to(self.device).float()

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                pred_cls, pred_bbox, pred_landmark = self.model(data)

                # compute the cls, bbox and landmark losses and combine them with weights
                cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls)
                bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox)
                landmark_loss = self.lossfn.landmark_loss(gt_cls, gt_landmark, pred_landmark)
                total_loss = cls_loss + 20 * bbox_loss + 20 * landmark_loss

                # compute classification accuracy
                accuracy = self.compute_accuracy(pred_cls, gt_cls)

                if phase == 'train':
                    total_loss.backward()
                    self.optimizer.step()

            cls_loss_.update(cls_loss, data.size(0))
            bbox_loss_.update(bbox_loss, data.size(0))
            landmark_loss_.update(landmark_loss, data.size(0))
            total_loss_.update(total_loss, data.size(0))
            accuracy_.update(accuracy, data.size(0))

            if batch_idx % 40 == 0:
                print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss: {:.6f} landmark Loss: {:.6f}\tAccuracy: {:.6f} LR: {:.7f}'.format(
                    phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset),
                    100. * batch_idx / len(self.dataloaders[phase]),
                    total_loss.item(), cls_loss.item(), bbox_loss.item(),
                    landmark_loss.item(), accuracy.item(),
                    self.optimizer.param_groups[0]['lr']))

        if phase == 'train':
            self.scheduler.step()

        print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} landmark Loss: {:.6f} Accuracy: {:.6f}".format(
            phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, landmark_loss_.avg, accuracy_.avg))

        # checkpoint the model whenever the validation loss improves
        if phase == 'val' and total_loss_.avg < self.best_val_loss:
            self.best_val_loss = total_loss_.avg
            torch.save(self.model.state_dict(), './pretrained_weights/mtcnn/best_onet_landmark_2.pth')

        return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, landmark_loss_.avg, accuracy_.avg
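
# A minimal, hypothetical usage sketch for ONetTrainer (not part of the original
# code). It assumes an ONet model class and 'train'/'val' dataloaders whose samples
# are dicts with 'input_img', 'cls_target', 'bbox_target' and 'landmark_target'
# keys, which is what train_epoch() above reads; the optimizer and scheduler
# choices are illustrative only.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = ONet().to(device)                                  # assumed model class
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
#   dataloaders = {'train': train_loader, 'val': val_loader}   # assumed DataLoaders
#   trainer = ONetTrainer(epochs=30, dataloaders=dataloaders, model=model,
#                         optimizer=optimizer, scheduler=scheduler, device=device)
#   trainer.train()
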
class RNetTrainer(object):
    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = Loss(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        # accuracy is computed only on samples whose label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # keep only the valid elements
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        # a sample is predicted positive when its probability is >= 0.6
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size))

    def update_lr(self, epoch):
        """
        Update the learning rate of the model optimizer.
        :param epoch: current training epoch
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def train(self, epoch):
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        # step the learning-rate scheduler once per epoch
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            gt_landmark = target['landmark_target']
            data, gt_label, gt_bbox, gt_landmark = data.to(self.device), \
                gt_label.to(self.device), gt_bbox.to(self.device).float(), \
                gt_landmark.to(self.device).float()
            # print(gt_label, gt_bbox, gt_landmark)

            cls_pred, box_offset_pred, landmark_offset_pred = self.model(data)

            # compute the individual losses and combine them with weights
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox, box_offset_pred)
            landmark_loss = self.lossfn.landmark_loss(gt_label, gt_landmark, landmark_offset_pred)
            total_loss = cls_loss + 2 * box_offset_loss + landmark_loss
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            print("loss:", cls_loss.cpu().detach().numpy(),
                  box_offset_loss.cpu().detach().numpy(),
                  landmark_loss.cpu().detach().numpy(),
                  total_loss.cpu().detach().numpy())

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            cls_loss_.update(cls_loss, data.size(0))
            box_offset_loss_.update(box_offset_loss, data.size(0))
            total_loss_.update(total_loss, data.size(0))
            landmark_loss_.update(landmark_loss, data.size(0))
            accuracy_.update(accuracy, data.size(0))

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(
                epoch, batch_idx * len(data), len(self.train_loader.dataset),
                100. * batch_idx / len(self.train_loader),
                total_loss.item(), accuracy.item()))

        # collect the epoch averages for logging
        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['landmark_loss'] = landmark_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.scheduler.get_lr()[0]

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, landmark_loss_.avg, total_loss_.avg, accuracy_.avg
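
# A similar hypothetical sketch for RNetTrainer (not part of the original code). It
# assumes an RNet model class and a train_loader yielding (data, target) pairs where
# target is a dict with 'label', 'bbox_target' and 'landmark_target' keys, as train()
# above expects; the logger may be None, and the optimizer/scheduler are illustrative.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = RNet().to(device)                                  # assumed model class
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
#   scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)
#   trainer = RNetTrainer(lr=0.01, train_loader=train_loader, model=model,
#                         optimizer=optimizer, scheduler=scheduler,
#                         logger=None, device=device)
#   for epoch in range(30):
#       trainer.train(epoch)
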