def test(epoch):
    """Evaluate RPN and classifier heads on ``test_loader`` after an epoch.

    Prints average RPN / classifier losses and the mean number of RPN
    proposals overlapping ground truth, and returns the combined mean loss
    (classifier + RPN) as a single float.

    Relies on module-level state: ``test_loader``, ``model_rpn``,
    ``model_classifier``, ``args``, ``config``, ``downscale``,
    ``num_anchors``, ``all_possible_anchor_boxes_tensor``, ``denormalize``
    and the device handles ``model_rpn_cuda`` / ``model_classifier_cuda``.
    """
    print("================================")
    print("Testing after epoch {}".format(epoch))
    print("================================")
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    regr_rpn_loss = 0
    class_rpn_loss = 0
    total_rpn_loss = 0
    regr_class_loss = 0
    class_class_loss = 0
    total_class_loss = 0
    count_rpn = 0
    count_class = 0
    total_count = 0
    for i, (image, boxes, labels, temp, num_pos) in enumerate(test_loader):
        count_rpn += 1
        y_is_box_label = temp[0].to(device=model_rpn_cuda)
        y_rpn_regr = temp[1].to(device=model_rpn_cuda)
        image = Variable(image).to(device=model_rpn_cuda)
        # Evaluation only: no backward pass happens anywhere in test(), so
        # run without autograd. The original also ran this forward pass
        # twice per batch; once is enough.
        with torch.no_grad():
            base_x, cls_k, reg_k = model_rpn(image)
            l1 = rpn_loss_regr(y_true=y_rpn_regr, y_pred=reg_k,
                               y_is_box_label=y_is_box_label,
                               lambda_rpn_regr=args.lambda_rpn_regr,
                               device=model_rpn_cuda)
            l2 = rpn_loss_cls_fixed_num(y_pred=cls_k,
                                        y_is_box_label=y_is_box_label,
                                        lambda_rpn_class=args.lambda_rpn_class)
        regr_rpn_loss += l1.item()
        class_rpn_loss += l2.item()
        loss = l1 + l2
        total_rpn_loss += loss.item()
        for b in range(image.size(0)):
            img_data = {}
            rpn_rois = rpn_to_roi(
                cls_k[b, :].cpu(), reg_k[b, :].cpu(),
                no_anchors=num_anchors,
                all_possible_anchor_boxes=all_possible_anchor_boxes_tensor.cpu().clone())
            # BUG FIX: Tensor.to() is NOT in-place -- the original discarded
            # the returned tensor, leaving rpn_rois on the CPU.
            rpn_rois = rpn_rois.to(device=model_classifier_cuda)
            img_data["boxes"] = boxes[b].to(device=model_classifier_cuda) // downscale
            img_data['labels'] = labels[b]
            X2, Y1, Y2, _ = calc_iou(rpn_rois, img_data,
                                     class_mapping=config.label_map)
            if X2 is None:
                # No proposal overlapped the ground truth for this image.
                rpn_accuracy_rpn_monitor.append(0)
                rpn_accuracy_for_epoch.append(0)
                continue
            X2 = X2.to(device=model_classifier_cuda)
            Y1 = Y1.to(device=model_classifier_cuda)
            Y2 = Y2.to(device=model_classifier_cuda)
            count_class += 1
            rpn_base = base_x[b].unsqueeze(0)
            with torch.no_grad():
                out_class, out_regr = model_classifier(base_x=rpn_base, rois=X2)
                l3 = class_loss_cls(y_true=Y1, y_pred=out_class,
                                    lambda_cls_class=args.lambda_cls_class)
                l4 = class_loss_regr(y_true=Y2, y_pred=out_regr,
                                     lambda_cls_regr=args.lambda_cls_regr)
            loss = l3 + l4
            class_class_loss += l3.item()
            regr_class_loss += l4.item()
            total_class_loss += loss.item()
            if args.save_evaluations:
                total_count += 1
                if total_count % 100 == 0:
                    # BUG FIX: clone before the in-place edits below; the
                    # original aliased X2 and silently corrupted it.
                    predicted_boxes = X2.clone()
                    # (x1, y1, w, h) -> (x1, y1, x2, y2), back to image scale.
                    predicted_boxes[:, 2] = predicted_boxes[:, 2] + predicted_boxes[:, 0]
                    predicted_boxes[:, 3] = predicted_boxes[:, 3] + predicted_boxes[:, 1]
                    predicted_boxes = predicted_boxes * downscale
                    # Undo the normalisation applied by the data loader.
                    temp_img = (denormalize['std'].cpu() * image[b].cpu()) + denormalize['mean'].cpu()
                    save_evaluations_image(image=temp_img, boxes=predicted_boxes,
                                           labels=Y1, count=total_count,
                                           config=config, save_dir=args.save_dir)
    # Guard: an empty test_loader would otherwise divide by zero.
    if count_rpn == 0:
        count_rpn = 1
    if count_class == 0:
        count_class = 1
        total_class_loss = 0
        print('[Test Accuracy] Classifier Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(0, 0, 0))
    else:
        print('[Test Accuracy] Classifier Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(
            class_class_loss / count_class, regr_class_loss / count_class, total_class_loss / count_class))
    print('[Test Accuracy] RPN Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(
        class_rpn_loss / count_rpn, regr_rpn_loss / count_rpn, total_rpn_loss / count_rpn))
    if len(rpn_accuracy_rpn_monitor) == 0:
        print('[Test Accuracy] RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
    else:
        mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
        print('[Test Accuracy] Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
    return total_class_loss / count_class + total_rpn_loss / count_rpn
# Resume-from-record setup: pull the per-metric history columns out of the
# training record dataframe (presumably a pandas DataFrame loaded from a
# previous run's CSV -- TODO confirm).
r_loss_rpn_cls = record_df['loss_rpn_cls']
r_loss_rpn_regr = record_df['loss_rpn_regr']
r_loss_class_cls = record_df['loss_class_cls']
r_loss_class_regr = record_df['loss_class_regr']
r_curr_loss = record_df['curr_loss']
r_elapsed_time = record_df['elapsed_time']
r_mAP = record_df['mAP']
# One dataframe row per 1000-batch epoch (see epoch_length below).
print('Already train %dK batches' % (len(record_df)))
# NOTE(review): this section uses Keras-style APIs (`Adam(lr=...)`,
# `model.compile(...)`) while train()/test() in this file are PyTorch --
# looks like a leftover from a Keras port; verify it is actually executed.
optimizer = Adam(lr=1e-5)
optimizer_classifier = Adam(lr=1e-5)
model_rpn.compile(
    optimizer=optimizer,
    loss=[rpn_loss_cls(num_anchors), rpn_loss_regr(num_anchors)])
model_classifier.compile(
    optimizer=optimizer_classifier,
    loss=[class_loss_cls, class_loss_regr(len(classes_count) - 1)],
    metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
# model_all is only compiled so the combined graph can be saved; the
# 'sgd'/'mae' settings are placeholders, not used for training.
model_all.compile(optimizer='sgd', loss='mae')
# Training setting
total_epochs = len(record_df)   # epochs already completed in the record
r_epochs = len(record_df)
epoch_length = 1000             # batches per reported "epoch"
num_epochs = 10                 # additional epochs to run
iter_num = 0
def train(epoch):
    """Run one training epoch: update the RPN per batch, then sample RoIs
    from its proposals and update the classifier head per image.

    Prints running RPN / classifier loss averages during the epoch and a
    summary at the end. Relies on module-level state: ``train_loader``,
    ``model_rpn``, ``model_classifier``, ``optimizer_model_rpn``,
    ``optimizer_classifier``, ``args``, ``config``, ``downscale``,
    ``num_anchors``, ``all_possible_anchor_boxes_tensor``, ``Dataset_roi``
    and the device handles ``model_rpn_cuda`` / ``model_classifier_cuda``.
    """
    print("\n\nTraining epoch {}\n\n".format(epoch))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    regr_rpn_loss = 0
    class_rpn_loss = 0
    total_rpn_loss = 0
    regr_class_loss = 0
    class_class_loss = 0
    total_class_loss = 0
    count_rpn = 0
    count_class = 0
    # BUG FIX: pre-define the loop indices used by the end-of-epoch prints;
    # the original raised NameError when a loop body never executed.
    i = b = j = 0
    for i, (image, boxes, labels, temp, num_pos) in enumerate(train_loader):
        count_rpn += 1
        y_is_box_label = temp[0].to(device=model_rpn_cuda)
        y_rpn_regr = temp[1].to(device=model_rpn_cuda)
        image = Variable(image).to(device=model_rpn_cuda)

        # ---- RPN forward / backward -------------------------------------
        base_x, cls_k, reg_k = model_rpn(image)
        l1 = rpn_loss_regr(y_true=y_rpn_regr, y_pred=reg_k,
                           y_is_box_label=y_is_box_label,
                           lambda_rpn_regr=args.lambda_rpn_regr,
                           device=model_rpn_cuda)
        l2 = rpn_loss_cls_fixed_num(y_pred=cls_k,
                                    y_is_box_label=y_is_box_label,
                                    lambda_rpn_class=args.lambda_rpn_class)
        regr_rpn_loss += l1.item()
        class_rpn_loss += l2.item()
        loss = l1 + l2
        total_rpn_loss += loss.item()
        optimizer_model_rpn.zero_grad()
        loss.backward()
        optimizer_model_rpn.step()

        # Re-run the just-updated RPN without autograd to produce proposals
        # for the classifier stage.
        with torch.no_grad():
            base_x, cls_k, reg_k = model_rpn(image)
        cls_k = cls_k.to(device=model_classifier_cuda)
        reg_k = reg_k.to(device=model_classifier_cuda)
        base_x = base_x.to(device=model_classifier_cuda)

        # BUG FIX: iterate the actual batch size rather than
        # args.train_batch -- the final batch may be smaller (test() already
        # did this correctly).
        for b in range(image.size(0)):
            img_data = {}
            with torch.no_grad():
                # Convert RPN output to RoI proposals.
                # cls_k[b] : (h, w, 9) scores; reg_k[b] : (h, w, 36) deltas;
                # all_possible_anchor_boxes_tensor : (4, h, w, 9).
                rpn_rois = rpn_to_roi(
                    cls_k[b, :].cpu(), reg_k[b, :].cpu(),
                    no_anchors=num_anchors,
                    all_possible_anchor_boxes=all_possible_anchor_boxes_tensor.cpu().clone())
                # BUG FIX: Tensor.to() is NOT in-place -- the original
                # discarded the returned tensor, leaving rpn_rois on the CPU.
                rpn_rois = rpn_rois.to(device=model_classifier_cuda)
                # The number of ground-truth boxes varies per image, so the
                # batch cannot be concatenated; handle one image at a time.
                img_data["boxes"] = boxes[b].to(device=model_classifier_cuda) // downscale
                img_data['labels'] = labels[b]
                # X2: qualified proposals; Y1: one-hot labels where the last
                # column marks background; Y2: regression mask + targets.
                X2, Y1, Y2, _ = calc_iou(rpn_rois, img_data,
                                         class_mapping=config.label_map)
            # BUG FIX: the None check must run BEFORE any .to() call; the
            # original called X2.to(...) first and crashed with
            # AttributeError whenever calc_iou found no matching bboxes.
            if X2 is None:
                rpn_accuracy_rpn_monitor.append(0)
                rpn_accuracy_for_epoch.append(0)
                continue
            X2 = X2.to(device=model_classifier_cuda)
            Y1 = Y1.to(device=model_classifier_cuda)
            Y2 = Y2.to(device=model_classifier_cuda)

            neg_samples = torch.where(Y1[:, -1] == 1)[0]
            pos_samples = torch.where(Y1[:, -1] == 0)[0]
            rpn_accuracy_rpn_monitor.append(pos_samples.size(0))
            rpn_accuracy_for_epoch.append(pos_samples.size(0))

            # Balanced pos/neg RoI mini-batches for the classifier head.
            db = Dataset_roi(pos=pos_samples.cpu(), neg=neg_samples.cpu())
            roi_loader = DataLoader(db, shuffle=True,
                                    batch_size=args.n_roi // 4,
                                    num_workers=args.workers // 2,
                                    pin_memory=False, drop_last=False)
            for j, potential_roi in enumerate(roi_loader):
                pos = potential_roi[0]
                neg = potential_roi[1]
                # An exhausted sampler side comes through as a plain list;
                # fall back to the other side, else use both.
                if type(pos) == list:
                    ind = neg
                elif type(neg) == list:
                    ind = pos
                else:
                    ind = torch.cat([pos, neg])
                rois = X2[ind]
                rpn_base = base_x[b].unsqueeze(0)
                Y11 = Y1[ind]
                Y22 = Y2[ind]
                count_class += 1
                rois = Variable(rois).to(device=model_classifier_cuda)

                # ---- classifier forward / backward ----------------------
                out_class, out_regr = model_classifier(base_x=rpn_base, rois=rois)
                l3 = class_loss_cls(y_true=Y11, y_pred=out_class,
                                    lambda_cls_class=args.lambda_cls_class)
                l4 = class_loss_regr(y_true=Y22, y_pred=out_regr,
                                     lambda_cls_regr=args.lambda_cls_regr)
                regr_class_loss += l4.item()
                class_class_loss += l3.item()
                loss = l3 + l4
                total_class_loss += loss.item()
                optimizer_classifier.zero_grad()
                loss.backward()
                optimizer_classifier.step()
                # (count_class >= 1 here, so the original's duplicated
                # "count_class == 0" zero-print branch was unreachable.)
                if count_class % args.display_class == 0:
                    print('[Classifier] RPN Ex: {}-th ,Batch : {}, Anchor Box: {}-th, Classifier Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(
                        i, b, j, class_class_loss / count_class,
                        regr_class_loss / count_class,
                        total_class_loss / count_class))

        if i % args.display_rpn == 0:
            if len(rpn_accuracy_rpn_monitor) == 0:
                print('[RPN] RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
            else:
                mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                print('[RPN] Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
            print('[RPN] RPN Ex: {}-th RPN Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(
                i, class_rpn_loss / count_rpn, regr_rpn_loss / count_rpn,
                total_rpn_loss / count_rpn))

    # ---- end-of-epoch summary -------------------------------------------
    print("-- END OF EPOCH -- {}".format(epoch))
    print("------------------------------")
    # Guards: an empty loader / no classifier batch would otherwise divide
    # by zero in the summary prints below.
    if count_rpn == 0:
        count_rpn = 1
    print('[RPN] RPN Ex: {}-th RPN Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(
        i, class_rpn_loss / count_rpn, regr_rpn_loss / count_rpn,
        total_rpn_loss / count_rpn))
    if count_class == 0:
        print('[Classifier] RPN Ex: {}-th ,Batch : {}, Anchor Box: {}-th, Classifier Model Classification loss: {} Regression loss: {} Total Loss: {}'.format(i, b, j, 0, 0, 0))
    else:
        print('[Classifier] RPN Ex: {}-th ,Batch : {}, Anchor Box: {}-th, Classifier Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(
            i, b, j, class_class_loss / count_class,
            regr_class_loss / count_class, total_class_loss / count_class))
    if len(rpn_accuracy_rpn_monitor) == 0:
        print('[RPN] RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
    else:
        mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
        print('[RPN] Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
    if count_class == 0:
        count_class = 1
    print('Total Loss {}'.format(total_class_loss / count_class + total_rpn_loss / count_rpn))
    print("------------------------------")
######################################################################################################################################################################################################################## # Loss ######################################################################################################################################################################################################################## from loss import rpn_loss_regr, rpn_loss_cls_fixed_num import torch y_rpn_regr = torch.rand(1, 10, 20, 36) pred = torch.rand(1, 10, 20, 36) y_is_box_label = torch.rand(1, 10, 20, 9) y_is_box_label = (y_is_box_label > 0.66).float() * 1 + (y_is_box_label < 0.33).float() * -1 l1 = rpn_loss_regr(y_true=y_rpn_regr, y_pred=pred, y_is_box_label=y_is_box_label) pred = torch.rand(1, 10, 20, 9) l2 = rpn_loss_cls_fixed_num(y_pred=pred, y_is_box_label=y_is_box_label) ######################################################################################################################################################################################################################## # default_anchors ######################################################################################################################################################################################################################## import math from tools import default_anchors import torch height = 800 width = 600