Exemplo n.º 1
0
def _build_roi_targets(model_rpn, X, boxes):
    """Build the ROI-head inputs and targets for one batch.

    Runs ``model_rpn`` on the batch ``X`` and decodes its predictions into
    region proposals with ``bbox_util.detection_out_rpn``. It then matches
    each image's proposals against its ground-truth ``boxes[i]`` with
    ``calc_iou``.

    Relies on the module-level names ``bbox_util``, ``config``,
    ``NUM_CLASSES``, ``get_img_output_length``, ``get_anchors`` and
    ``calc_iou``.

    Returns:
        Tuple ``(roi_inputs, out_classes, out_regrs)``. Each element is a
        NumPy array built from one ``calc_iou`` result per image in ``X``.
    """
    P_rpn = model_rpn.predict_on_batch(X)

    # Anchors are computed once per batch from the shape of the first image,
    # so every image in the batch is assumed to share that size.
    height, width, _ = np.shape(X[0])
    base_feature_width, base_feature_height = get_img_output_length(
        width, height)
    anchors = get_anchors([base_feature_width, base_feature_height],
                          width, height)
    results = bbox_util.detection_out_rpn(P_rpn, anchors)

    roi_inputs = []
    out_classes = []
    out_regrs = []
    for i in range(len(X)):
        # Keep columns 1: of every proposal row. Column 0 is presumably the
        # proposal score (TODO confirm against detection_out_rpn).
        R = results[i][:, 1:]
        X2, Y1, Y2 = calc_iou(R, config, boxes[i], NUM_CLASSES)
        roi_inputs.append(X2)
        out_classes.append(Y1)
        out_regrs.append(Y2)
    return np.array(roi_inputs), np.array(out_classes), np.array(out_regrs)


def fit_one_epoch(model_rpn, model_all, epoch, epoch_size, epoch_size_val, gen,
                  genval, Epoch, callback):
    """Train ``model_all`` for one epoch, validate it, and save its weights.

    Args:
        model_rpn: Model used only via ``predict_on_batch`` here. Its output
            is decoded into region proposals.
        model_all: Model trained with ``train_on_batch`` and evaluated with
            ``test_on_batch``. Its weights are saved at the end.
        epoch: Zero-based index of the current epoch.
        epoch_size: Maximum number of training batches drawn from ``gen``.
        epoch_size_val: Maximum number of validation batches drawn from
            ``genval``.
        gen: Training batch iterable. Items 0, 1 and 2 of each batch are used
            as (images, RPN targets, ground-truth boxes).
        genval: Validation batch iterable with the same layout as ``gen``.
        Epoch: Total number of epochs. Used only in progress/log output.
        callback: Passed through to ``write_log`` for per-batch training
            loss logging.

    Side effects:
        Appends the epoch's mean train/val loss to the module-level
        ``loss_history``. Writes a weights file under ``logs/``.
    """
    rpn_loc_loss = 0
    rpn_cls_loss = 0
    roi_loc_loss = 0
    roi_cls_loss = 0
    total_loss = 0
    val_total_loss = 0
    # Count batches actually processed. The generators may yield fewer than
    # epoch_size / epoch_size_val batches, and the epoch means must divide by
    # the true count.
    train_steps = 0
    val_steps = 0

    with tqdm(total=epoch_size,
              desc=f'Epoch {epoch + 1}/{Epoch}',
              mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break
            X, Y, boxes = batch[0], batch[1], batch[2]
            roi_inputs, out_classes, out_regrs = _build_roi_targets(
                model_rpn, X, boxes)

            loss_class = model_all.train_on_batch(
                [X, roi_inputs],
                [Y[0], Y[1], out_classes, out_regrs])

            # Use a global step so later epochs do not overwrite the log
            # points written by earlier ones.
            write_log(callback, [
                'total_loss', 'rpn_cls_loss', 'rpn_reg_loss',
                'detection_cls_loss', 'detection_reg_loss'
            ], loss_class, epoch * epoch_size + iteration)

            train_steps += 1
            # loss_class[0] is the model's total loss. Using it for both train
            # and validation keeps the two totals directly comparable.
            total_loss += loss_class[0]
            rpn_cls_loss += loss_class[1]
            rpn_loc_loss += loss_class[2]
            roi_cls_loss += loss_class[3]
            roi_loc_loss += loss_class[4]

            pbar.set_postfix(
                **{
                    'total': total_loss / train_steps,
                    'rpn_cls': rpn_cls_loss / train_steps,
                    'rpn_loc': rpn_loc_loss / train_steps,
                    'roi_cls': roi_cls_loss / train_steps,
                    'roi_loc': roi_loc_loss / train_steps,
                    # Report the lr of the optimizer that is actually stepping.
                    'lr': K.get_value(model_all.optimizer.lr)
                })
            pbar.update(1)

    print('Start Validation')
    with tqdm(total=epoch_size_val,
              desc=f'Epoch {epoch + 1}/{Epoch}',
              mininterval=0.3) as pbar:
        for iteration, batch in enumerate(genval):
            if iteration >= epoch_size_val:
                break
            X, Y, boxes = batch[0], batch[1], batch[2]
            roi_inputs, out_classes, out_regrs = _build_roi_targets(
                model_rpn, X, boxes)

            loss_class = model_all.test_on_batch(
                [X, roi_inputs],
                [Y[0], Y[1], out_classes, out_regrs])

            val_steps += 1
            val_total_loss += loss_class[0]
            pbar.set_postfix(**{'total': val_total_loss / val_steps})
            pbar.update(1)

    # Mean over the batches actually processed. max(..., 1) avoids dividing
    # by zero when a generator yields nothing.
    train_loss = total_loss / max(train_steps, 1)
    val_loss = val_total_loss / max(val_steps, 1)

    loss_history.append_loss(train_loss, val_loss)
    print('Finish Validation')
    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('Total Loss: %.4f || Val Loss: %.4f ' % (train_loss, val_loss))

    print('Saving state, iter:', str(epoch + 1))
    model_all.save_weights('logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.h5' %
                           ((epoch + 1), train_loss, val_loss))
Exemplo n.º 2
0
                print(
                    'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                    .format(mean_overlapping_bboxes, EPOCH_LENGTH))
                if mean_overlapping_bboxes == 0:
                    print(
                        'RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.'
                    )

            X, Y, boxes = batch[0], batch[1], batch[2]

            loss_rpn = model_rpn.train_on_batch(X, Y)
            write_log(callback, ['rpn_cls_loss', 'rpn_reg_loss'], loss_rpn,
                      train_step)
            P_rpn = model_rpn.predict_on_batch(X)
            height, width, _ = np.shape(X[0])
            anchors = get_anchors(get_img_output_length(width, height), width,
                                  height)

            # 将预测结果进行解码
            results = bbox_util.detection_out(P_rpn,
                                              anchors,
                                              1,
                                              confidence_threshold=0)

            R = results[0][:, 2:]

            X2, Y1, Y2, IouS = calc_iou(R, config, boxes[0], width, height,
                                        NUM_CLASSES)

            if X2 is None:
                rpn_accuracy_rpn_monitor.append(0)