def train(epoch):
    print("\n\nTraining epoch {}\n\n".format(epoch))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    
    regr_rpn_loss = 0
    class_rpn_loss = 0
    total_rpn_loss = 0

    regr_class_loss = 0
    class_class_loss = 0
    total_class_loss = 0

    count_rpn = 0
    count_class = 0


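    # Per-batch training flow:
    #   1. forward the RPN and update it on its classification + regression losses
    #   2. re-run the RPN under torch.no_grad() to obtain detached proposals
    #   3. per image, convert proposals to ROIs, match them against the ground truth,
    #      sample positive/negative ROIs, and update the classifier head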
    for i,(image, boxes, labels , temp, num_pos) in enumerate(train_loader):
        count_rpn +=1

        y_is_box_label = temp[0].to(device=model_rpn_cuda)
        y_rpn_regr = temp[1].to(device=model_rpn_cuda)
        image = Variable(image).to(device=model_rpn_cuda)
        base_x , cls_k , reg_k = model_rpn(image)
        
        l1 = rpn_loss_regr(y_true=y_rpn_regr, y_pred=reg_k , y_is_box_label=y_is_box_label , lambda_rpn_regr=args.lambda_rpn_regr , device=model_rpn_cuda)
        l2 = rpn_loss_cls_fixed_num(y_pred = cls_k , y_is_box_label= y_is_box_label , lambda_rpn_class = args.lambda_rpn_class)
        
        regr_rpn_loss += l1.item() 
        class_rpn_loss += l2.item() 
        loss = l1 + l2 
        total_rpn_loss += loss.item()
        
        optimizer_model_rpn.zero_grad()
        loss.backward()
        optimizer_model_rpn.step()                        

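        # Second, gradient-free forward pass: the classifier stage below treats the
        # RPN outputs as fixed inputs, so no gradients flow back into the RPN here.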
        with torch.no_grad():

            base_x , cls_k , reg_k = model_rpn(image)
            
            cls_k = cls_k.to(device=model_classifier_cuda)
            reg_k = reg_k.to(device=model_classifier_cuda)
            base_x = base_x.to(device=model_classifier_cuda)
            

        # iterate over the actual batch size (the last batch may be smaller than args.train_batch)
        for b in range(image.size(0)):
            img_data = {}
            with torch.no_grad():
                # Convert rpn layer to roi bboxes
                # cls_k.shape : b, h, w, 9
                # reg_k : b, h, w, 36
                # model_classifier_cuda
                # cls_k[b,:].shape == [13, 10, 9]
                # reg_k[b,:].shape == [13, 10, 36]
                # num_anchors = 9
                # all_possible_anchor_boxes_tensor.shape == [4, 13, 10, 9]
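                # rpn_to_roi converts the raw RPN scores/deltas into box proposals
                # (roughly 300 per image here; see the shape notes below)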
                
                rpn_rois = rpn_to_roi(cls_k[b,:].cpu(), reg_k[b,:].cpu(), no_anchors=num_anchors,  all_possible_anchor_boxes=all_possible_anchor_boxes_tensor.cpu().clone() )
                rpn_rois = rpn_rois.to(device=model_classifier_cuda)
                # the batch cannot be concatenated here: the number of ground-truth
                # boxes may vary from image to image
                img_data["boxes"] = boxes[b].to(device=model_classifier_cuda) // downscale
                img_data['labels'] = labels[b]

                # rpn_rois : 300, 4
                # img_data["boxes"].shape : 68,4
                # len(img_data["labels"]) : 68

                # X2 holds the qualified ROIs from model_rpn (converted anchors)
                # Y1 holds the class labels (one row per ROI); Y1[:, -1] == 1 marks background
                # (negative) ROIs; ambiguous (neutral) ROIs below the minimum overlap threshold are dropped
                # Y2 is the concatenation of [1, tx, ty, tw, th] and [0, tx, ty, tw, th]
                X2, Y1, Y2, _ = calc_iou(rpn_rois, img_data, class_mapping=config.label_map )

                # X2 is None means no ROI matched the ground-truth boxes;
                # check this before moving the tensors to the classifier device
                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                X2 = X2.to(device=model_classifier_cuda)
                Y1 = Y1.to(device=model_classifier_cuda)
                Y2 = Y2.to(device=model_classifier_cuda)

                neg_samples = torch.where(Y1[:, -1] == 1)[0]
                pos_samples = torch.where(Y1[:, -1] == 0)[0]
                rpn_accuracy_rpn_monitor.append(pos_samples.size(0))
                rpn_accuracy_for_epoch.append(pos_samples.size(0))
            
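            # Sample minibatches of positive/negative ROI indices for the classifier head.
            # When one side has no samples left, the loader yields a plain list instead of a
            # tensor for it, which the isinstance checks below handle.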
            db = Dataset_roi(pos=pos_samples.cpu() , neg= neg_samples.cpu())
            roi_loader = DataLoader(db, shuffle=True,  
                batch_size=args.n_roi // 4, num_workers=args.workers//2, pin_memory=False, drop_last=False)
            # list(roi_loader)
            for j,potential_roi in enumerate(roi_loader):
                pos = potential_roi[0]
                neg = potential_roi[1]
                if isinstance(pos, list):
                    # no positive ROIs in this minibatch: train on negatives only
                    rois = X2[neg]
                    rpn_base = base_x[b].unsqueeze(0)
                    Y11 = Y1[neg]
                    Y22 = Y2[neg]
                    # out_class : args.n_roi // 2 , # no of class
                elif isinstance(neg, list):
                    # no negative ROIs in this minibatch: train on positives only
                    rois = X2[pos]
                    rpn_base = base_x[b].unsqueeze(0)
                    #out_class :  args.n_roi // 2 , # no of class
                    Y11 = Y1[pos]
                    Y22 = Y2[pos]
                else:
                    # mix positive and negative ROIs in the same minibatch
                    ind = torch.cat([pos, neg])
                    rois = X2[ind]
                    rpn_base = base_x[b].unsqueeze(0)
                    #out_class:  args.n_roi , # no of class
                    Y11 = Y1[ind]
                    Y22 = Y2[ind]
                
                # Expected shapes (sanity check; if you are not seeing these, something is wrong):
                #   Y11.shape       == (20, 8)
                #   Y22.shape       == (20, 56)
                #   out_class.shape == (20, 8)
                #   out_regr.shape  == (20, 56)
                #   rois.shape      == (20, 4)
                #   rpn_base.shape  == torch.Size([1, 2048, 50, 38])
                count_class += 1
                rois = Variable(rois).to(device=model_classifier_cuda)
                out_class , out_regr = model_classifier(base_x = rpn_base , rois= rois )
                # torch.Size([5, 2048, 7, 7]) torch.Size([5, 4])

                l3 = class_loss_cls(y_true=Y11, y_pred=out_class , lambda_cls_class=args.lambda_cls_class)
                l4 = class_loss_regr(y_true=Y22, y_pred= out_regr , lambda_cls_regr= args.lambda_cls_regr)

                regr_class_loss += l4.item()
                class_class_loss += l3.item()   

                loss = l3 + l4 
                total_class_loss += loss.item()
                
                
                optimizer_classifier.zero_grad()
                loss.backward()
                optimizer_classifier.step()

                if count_class % args.display_class == 0 :
                    # count_class was incremented above, so it is always >= 1 here
                    print('[Classifier] RPN Ex: {}-th, Batch: {}, Anchor Box: {}-th, Classifier Model Classification loss: {} Regression loss: {} Total Loss: {}'.format(
                        i, b, j, class_class_loss / count_class, regr_class_loss / count_class, total_class_loss / count_class))

        if i % args.display_rpn == 0 :
            if len(rpn_accuracy_rpn_monitor) == 0 :
                print('[RPN] RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
            else:
                mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
                print('[RPN] Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes)) 
            print('[RPN] RPN Ex: {}-th RPN Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(i ,class_rpn_loss / count_rpn, regr_rpn_loss / count_rpn ,total_rpn_loss/ count_rpn ))

    print("-- END OF EPOCH -- {}".format(epoch)) 
    print("------------------------------" ) 
    print('[RPN] RPN Ex: {}-th RPN Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(i ,class_rpn_loss / count_rpn, regr_rpn_loss / count_rpn ,total_rpn_loss/ count_rpn ))
    if count_class == 0 :
        # guard against undefined b/j and division by zero when the classifier never ran
        print('[Classifier] No classifier minibatches were trained this epoch.')
        count_class = 1
    else:
        print('[Classifier] RPN Ex: {}-th, Batch: {}, Anchor Box: {}-th, Classifier Model Classification loss: {} Regression loss: {} Total Loss: {}'.format(
            i, b, j, class_class_loss / count_class, regr_class_loss / count_class, total_class_loss / count_class))
    if len(rpn_accuracy_rpn_monitor) == 0 :
        print('[RPN] RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
    else:
        mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
        print('[RPN] Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
    print('Total Loss  {}'.format(total_class_loss / count_class + total_rpn_loss / count_rpn))
    print("------------------------------" ) 
@torch.no_grad()
def test(epoch):
    print("================================")
    print("Testing after epoch {}".format(epoch))
    print("================================")
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    
    regr_rpn_loss = 0
    class_rpn_loss = 0
    total_rpn_loss = 0

    regr_class_loss = 0
    class_class_loss = 0
    total_class_loss = 0

    count_rpn = 0
    count_class = 0

    total_count = 0

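    # Evaluation mirrors the training loop but performs no parameter updates: RPN losses
    # are accumulated per batch and classifier losses per image, and every 100th matched
    # image can optionally be saved for visual inspection (args.save_evaluations).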
    for i,(image, boxes, labels , temp, num_pos) in enumerate(test_loader):

        count_rpn +=1
        
        y_is_box_label = temp[0].to(device=model_rpn_cuda)
        y_rpn_regr = temp[1].to(device=model_rpn_cuda)
        image = Variable(image).to(device=model_rpn_cuda)

        base_x , cls_k , reg_k = model_rpn(image)
        l1 = rpn_loss_regr(y_true=y_rpn_regr, y_pred=reg_k , y_is_box_label=y_is_box_label , lambda_rpn_regr=args.lambda_rpn_regr, device=model_rpn_cuda)
        l2 = rpn_loss_cls_fixed_num(y_pred = cls_k , y_is_box_label= y_is_box_label , lambda_rpn_class = args.lambda_rpn_class)
        
        regr_rpn_loss += l1.item() 
        class_rpn_loss += l2.item() 
        loss = l1 + l2 
        total_rpn_loss += loss.item()
        
        base_x , cls_k , reg_k = model_rpn(image)

        for b in range(image.size(0)):
            img_data = {}
            rpn_rois = rpn_to_roi(cls_k[b,:].cpu(), reg_k[b,:].cpu(), no_anchors=num_anchors,  all_possible_anchor_boxes=all_possible_anchor_boxes_tensor.cpu().clone() )
            rpn_rois = rpn_rois.to(device=model_classifier_cuda)

            img_data["boxes"] = boxes[b].to(device=model_classifier_cuda) // downscale
            img_data['labels'] = labels[b]

            X2, Y1, Y2, _ = calc_iou(rpn_rois, img_data, class_mapping=config.label_map )
            
            if X2 is None:
                rpn_accuracy_rpn_monitor.append(0)
                rpn_accuracy_for_epoch.append(0)
                continue

            X2 = X2.to(device=model_classifier_cuda)
            Y1 = Y1.to(device=model_classifier_cuda)
            Y2 = Y2.to(device=model_classifier_cuda)

            count_class += 1 
            rpn_base = base_x[b].unsqueeze(0)
            out_class , out_regr = model_classifier(base_x = rpn_base , rois= X2 )
            
            l3 = class_loss_cls(y_true=Y1, y_pred=out_class , lambda_cls_class=args.lambda_cls_class)
            l4 = class_loss_regr(y_true=Y2, y_pred= out_regr , lambda_cls_regr= args.lambda_cls_regr)
            loss = l3 + l4 
            class_class_loss += l3.item()   
            regr_class_loss += l4.item()
            total_class_loss += loss.item()

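            # Optionally dump every 100th evaluated image with its proposal boxes drawn,
            # after undoing the normalization and mapping ROIs back to image coordinates.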
            if args.save_evaluations :
                total_count += 1
                if total_count % 100 == 0 :
                    # convert ROIs from (x, y, w, h) on the feature map to
                    # (x1, y1, x2, y2) in image coordinates (clone to avoid mutating X2)
                    predicted_boxes = X2.clone()
                    predicted_boxes[:, 2] = predicted_boxes[:, 2] + predicted_boxes[:, 0]
                    predicted_boxes[:, 3] = predicted_boxes[:, 3] + predicted_boxes[:, 1]
                    predicted_boxes = predicted_boxes * downscale
                    
                    temp_img = (denormalize['std'].cpu() * image[b].cpu()) + denormalize['mean'].cpu()  
                    save_evaluations_image(image=temp_img, boxes=predicted_boxes, labels=Y1, count=total_count, config=config , save_dir=args.save_dir)
    
    if count_class == 0 :
        # no classifier minibatches were evaluated; avoid division by zero in the return below
        count_class = 1
        total_class_loss = 0
        print('[Test Accuracy] Classifier Model Classification loss: {} Regression loss: {} Total Loss: {}'.format(0, 0, 0))
    else:
        print('[Test Accuracy] Classifier Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(class_class_loss / count_class, regr_class_loss / count_class ,total_class_loss/ count_class ))

    print('[Test Accuracy] RPN Model Classification loss: {} Regression loss: {} Total Loss: {} '.format(class_rpn_loss / count_rpn, regr_rpn_loss / count_rpn ,total_rpn_loss/ count_rpn ))
    if len(rpn_accuracy_rpn_monitor) == 0 :
        print('[Test Accuracy] RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')        
    else:
        mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
        print('[Test Accuracy] Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes)) 
    
    return total_class_loss/ count_class + total_rpn_loss/ count_rpn
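
A minimal driver sketch (not part of the original example) showing how these train/test routines might be wired together, assuming an `args.epochs` argument; the checkpoint path and dictionary keys below are illustrative assumptions only.

# Hypothetical driver loop: train and evaluate once per epoch, keeping the checkpoint
# with the lowest validation loss. `args.epochs` and the checkpoint layout are assumptions.
import os
import torch

best_loss = float('inf')
for epoch in range(args.epochs):
    train(epoch)
    val_loss = test(epoch)
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({'epoch': epoch,
                    'model_rpn': model_rpn.state_dict(),
                    'model_classifier': model_classifier.state_dict(),
                    'val_loss': val_loss},
                   os.path.join(args.save_dir, 'best_faster_rcnn.pth'))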
Example No. 3
    r_curr_loss = record_df['curr_loss']
    r_elapsed_time = record_df['elapsed_time']
    r_mAP = record_df['mAP']

    print('Already trained %dK batches' % (len(record_df)))

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(
        optimizer=optimizer,
        loss=[rpn_loss_cls(num_anchors),
              rpn_loss_regr(num_anchors)])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[class_loss_cls,
              class_loss_regr(len(classes_count) - 1)],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')
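    # model_all is compiled with a dummy 'mae' loss; in keras-frcnn-style code this is
    # typically done only so the combined model's weights can be saved and reloaded.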

    # Training setting
    total_epochs = len(record_df)
    r_epochs = len(record_df)

    epoch_length = 1000
    num_epochs = 10
    iter_num = 0

    total_epochs += num_epochs

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []