def _load_KLDIV_loss_function(device):
    """Build the KL-divergence loss and its softmax helpers on ``device``.

    Args:
        device: Target passed to ``Module.to`` for each returned module.

    Returns:
        tuple: ``(loss_fct, softmax, logSoftmax)``. These are a
        ``torch.nn.KLDivLoss`` with ``reduction='batchmean'``, a ``Softmax``
        and a ``LogSoftmax``, both over the last dimension (``dim=-1``).
    """
    modules = (
        torch.nn.KLDivLoss(reduction='batchmean'),
        Softmax(dim=-1),
        LogSoftmax(dim=-1),
    )
    # Module.to moves parameters/buffers in place; the return value is unused.
    for module in modules:
        module.to(device)
    return modules
def model_train(self, epoch_offset=0, lamda=10, nreg=2400, ncls=256):
    """Train the model with SGD on per-anchor classification and regression.

    Anchor targets are computed once per image, from the boxes in
    ``BBOX_XYWH_JSON_PATH``, using
    ``AnchorBox.calculate_p_for_each_anchor_box``. That call yields the
    foreground indices, the background indices and a regression target.
    Every epoch then performs one optimizer step per image on::

        cls_loss / ncls + reg_loss * lamda / nreg

    ``cls_loss`` is the NLL loss of the log-softmaxed class scores of the
    selected anchors (foreground labelled 1, background labelled 0).
    ``reg_loss`` is the Smooth-L1 loss between the predicted and target
    regression values of the foreground anchors.

    The learning rate follows ``StepLR(step_size=SCHEDULER_STEP,
    gamma=SCHEDULER_GAMMA)``, stepped once per epoch. A checkpoint is
    written whenever the epoch index is a multiple of
    ``EPOCH_SAVE_INTERVAL``.

    Args:
        epoch_offset: Added to the epoch number in log messages and in the
            checkpoint file name. It does not shift the learning-rate
            schedule.
        lamda: Weight of the regression term.
        nreg: Normaliser of the regression term.
        ncls: Normaliser of the classification term.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read an image listed in
            the JSON file.
    """
    LOGGER.info('Started Training with an offset of %s', str(epoch_offset))
    create_dir(MODEL_SAVE_PATH)
    optimizer = SGD(self.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    scheduler = StepLR(optimizer, step_size=SCHEDULER_STEP,
                       gamma=SCHEDULER_GAMMA)
    LOGGER.info(
        'Learning Rate: %s, Momentum: %s, Scheduler_step: %s, scheduler_gamma: %s',
        str(LEARNING_RATE), str(MOMENTUM), str(SCHEDULER_STEP),
        str(SCHEDULER_GAMMA))
    loss_for_classification = NLLLoss()
    loss_for_regression = SmoothL1Loss()
    # Stateless module, so build it once rather than once per image.
    log_softmax = LogSoftmax(dim=1)

    img_coors_json = read_json_file(BBOX_XYWH_JSON_PATH)
    anchor_box = AnchorBox()

    # The anchor assignment depends only on the ground-truth boxes, so compute
    # it once per image up front. Dict iteration order is stable, so `targets`
    # lines up with iterating `img_coors_json` again below.
    targets = []
    for coors in img_coors_json.values():
        li_fore_index, li_back_index, reg_ten_actual = \
            anchor_box.calculate_p_for_each_anchor_box(
                anchor_box.anchor_boxes, coors)
        targets.append((li_fore_index, li_back_index, reg_ten_actual))

    def _as_index(indices):
        # Convert an anchor-index collection into a 1-D long tensor usable by
        # index_select. An empty collection gives a length-0 index.
        # NOTE(review): assumes the indices are ints (or an int array).
        return torch.as_tensor(indices, dtype=torch.long, device=self.device)

    for epoch in range(EPOCHS):
        epoch_loss = 0.0
        # get_last_lr() reports the LR actually in use. get_lr() outside of
        # step() returns lr * gamma at step boundaries, which is the wrong value.
        LOGGER.debug('Epoch: %s, Current Learning Rate: %s',
                     str(epoch + epoch_offset), str(scheduler.get_last_lr()))
        for image, (li_foreground_index, li_background_index,
                    reg_tensor_actual) in zip(img_coors_json, targets):
            img = cv2.imread(NORMALISED_IMAGES_PATH + image)
            if img is None:
                raise FileNotFoundError(
                    'Could not read image ' + NORMALISED_IMAGES_PATH + image)
            # (H, W, C) array -> (1, C, H, W) float tensor.
            img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
            img = img.to(self.device)
            pred_cls, pred_reg = self.forward(img)

            fg_index = _as_index(li_foreground_index)
            bg_index = _as_index(li_background_index)

            # Gather the selected anchors' scores in one op (foreground rows
            # first, then background). This replaces growing tensors with
            # torch.cat inside a loop, which is quadratic.
            pred_cls_only_background_foreground = log_softmax(
                pred_cls[0].index_select(0, torch.cat((fg_index, bg_index))))
            pred_torch_reg = pred_reg[0].index_select(0, fg_index)
            # Labels in the same order: 1 for foreground, 0 for background.
            exp_torch_fg_bg = torch.cat(
                (torch.ones_like(fg_index), torch.zeros_like(bg_index)))

            cls_loss = loss_for_classification(
                pred_cls_only_background_foreground, exp_torch_fg_bg)
            # (input=prediction, target=ground truth). Smooth-L1 is symmetric,
            # so the value matches the previous argument order.
            reg_loss = loss_for_regression(
                pred_torch_reg, reg_tensor_actual.to(self.device))
            total_image_loss = (cls_loss / ncls) + (reg_loss * lamda / nreg)

            optimizer.zero_grad()
            total_image_loss.backward()
            optimizer.step()
            epoch_loss += total_image_loss.item()

        # Step once per epoch, after the optimizer steps. This gives the same
        # StepLR schedule as the deprecated scheduler.step(epoch) form.
        scheduler.step()

        LOGGER.debug('Loss at Epoch %s: %s',
                     str(epoch + epoch_offset), str(epoch_loss))
        if epoch % EPOCH_SAVE_INTERVAL == 0:
            torch.save(
                self.state_dict(),
                MODEL_SAVE_PATH + 'model_epc_' + str(epoch + epoch_offset) + '.pt')
        if epoch % 5 == 0:
            LOGGER.info('Loss at Epoch %s: %s',
                        str(epoch + epoch_offset), str(epoch_loss))