Example #1
def train_kitti():
    # config for data augmentation
    cfg = config.Config()

    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 32
    cfg.base_net_weights = os.path.join('./model/', nn.get_weight_path())

    # TODO: this is the only file that should be changed to train on other data
    cfg.model_path = './model/kitti_frcnn_last.hdf5'
    cfg.simple_label_file = 'kitti_simple_label.txt'

    all_images, classes_count, class_mapping = get_data(cfg.simple_label_file)

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    cfg.class_mapping = class_mapping
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(
            cfg.config_save_file))

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    random.shuffle(all_images)
    num_imgs = len(all_images)
    train_imgs = [s for s in all_images if s['imageset'] == 'trainval']
    val_imgs = [s for s in all_images if s['imageset'] == 'test']

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, cfg, nn.get_img_output_length,
                                                   K.image_dim_ordering(), mode='train')
    data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, cfg, nn.get_img_output_length,
                                                 K.image_dim_ordering(), mode='val')

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)

    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    classifier = nn.classifier(shared_layers, roi_input, cfg.num_rois, nb_classes=len(classes_count), trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    try:
        print('loading weights from {}'.format(cfg.model_path))
        model_rpn.load_weights(cfg.model_path, by_name=True)
        model_classifier.load_weights(cfg.model_path, by_name=True)
    except Exception as e:
        print(e)
        print('Could not load pretrained model weights. Weights can be found in the keras application folder '
              'https://github.com/fchollet/keras/tree/master/keras/applications')

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[losses_fn.rpn_loss_cls(num_anchors), losses_fn.rpn_loss_regr(num_anchors)])
    model_classifier.compile(optimizer=optimizer_classifier,
                             loss=[losses_fn.class_loss_cls, losses_fn.class_loss_regr(len(classes_count) - 1)],
                             metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 1000
    num_epochs = int(cfg.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')

    vis = True

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:

                if len(rpn_accuracy_rpn_monitor) == epoch_length and cfg.verbose:
                    mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(
                            mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print('RPN is not producing bounding boxes that overlap'
                              ' the ground truth boxes. Check RPN settings or keep training.')

                X, Y, img_data = next(data_gen_train)

                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                result = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], cfg, K.image_dim_ordering(), use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(result, img_data, cfg, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if cfg.num_rois > 1:
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(pos_samples, cfg.num_rois // 2, replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(neg_samples, cfg.num_rois - len(selected_pos_samples),
                                                                replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(neg_samples, cfg.num_rois - len(selected_pos_samples),
                                                                replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]],
                                                             [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(iter_num,
                               [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                ('detector_cls', np.mean(losses[:iter_num, 2])),
                                ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(
                            mean_overlapping_bboxes))
                        print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(loss_class_cls))
                        print('Loss Detector regression: {}'.format(loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() - start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            print('Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                # save model
                model_all.save_weights(cfg.model_path)
                continue
    print('Training complete, exiting.')
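The positive/negative balancing used above when cfg.num_rois > 1 can be exercised in isolation. A minimal sketch, assuming pos_samples and neg_samples are 1-D numpy index arrays and num_rois is the classifier batch size (sample_rois is a hypothetical helper, not part of the original code):

import numpy as np

def sample_rois(pos_samples, neg_samples, num_rois):
    # take at most num_rois // 2 positives, mirroring the selection loop above
    if len(pos_samples) < num_rois // 2:
        sel_pos = pos_samples.tolist()
    else:
        sel_pos = np.random.choice(pos_samples, num_rois // 2, replace=False).tolist()
    # fill the remainder with negatives, falling back to sampling with replacement
    n_neg = num_rois - len(sel_pos)
    replace = len(neg_samples) < n_neg
    sel_neg = np.random.choice(neg_samples, n_neg, replace=replace).tolist()
    return sel_pos + sel_neg

# e.g. 3 positive and 20 negative RoIs with 32 RoIs per classifier batch
print(sample_rois(np.arange(3), np.arange(3, 23), 32))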
Example #2
if options.network == 'resnet50':
        C.network = 'resnet50'
        from keras_frcnn import resnet as nn
elif options.network == 'vgg':
        from keras_frcnn import vgg as nn
        C.network = 'vgg'
else:
        print('Not a valid model')
        raise ValueError


# check if weight path was passed via command line
if options.input_weight_path:
        C.base_net_weights = options.input_weight_path
else:
        # set the path to weights based on backend and model
        C.base_net_weights = nn.get_weight_path()

train_imgs, classes_count, class_mapping = get_data(options.train_path,'train')
val_imgs, _, _ = get_data(options.train_path,'val')

if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

inv_map = {v: k for k, v in class_mapping.items()}

print('Training images per class:')
pprint.pprint(classes_count)
print(f'Num classes (including bg) = {len(classes_count)}')
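The 'bg' bookkeeping in the snippet above is easier to see on a toy mapping; the class names below are made up for illustration:

# hypothetical illustration of the class bookkeeping above
class_mapping = {'Car': 0, 'Pedestrian': 1}
classes_count = {'Car': 120, 'Pedestrian': 80}
if 'bg' not in classes_count:
    classes_count['bg'] = 0                   # no 'background' boxes in the label file
    class_mapping['bg'] = len(class_mapping)  # 'bg' gets the next free index, here 2
inv_map = {v: k for k, v in class_mapping.items()}
print(class_mapping)  # {'Car': 0, 'Pedestrian': 1, 'bg': 2}
print(inv_map)        # {0: 'Car', 1: 'Pedestrian', 2: 'bg'}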
def test_kitti():
    # config for data augmentation
    cfg = config.Config()

    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 32
    cfg.base_net_weights = os.path.join('./model/', nn.get_weight_path())

    # TODO: this is the only file that should be changed to train on other data
    cfg.model_path = './model/kitti_frcnn_last.hdf5'

    cfg.simple_label_file = 'kitti_simple_label.txt'
    # view the absolute path
    #t = os.path.abspath('kitti_simple_label.txt')

    all_images, classes_count, class_mapping = get_data(cfg.simple_label_file)
    pedestrain_num = classes_count['Pedestrian']
    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    cfg.class_mapping = class_mapping
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(cfg.config_save_file))

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    random.shuffle(all_images)
    num_imgs = len(all_images)
    train_imgs = [s for s in all_images if s['imageset'] == 'trainval']
    val_imgs = [s for s in all_images if s['imageset'] == 'test']

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = data_generators.get_anchor_gt(train_imgs,
                                                   classes_count,
                                                   cfg,
                                                   nn.get_img_output_length,
                                                   K.image_dim_ordering(),
                                                   mode='train')
    data_gen_val = data_generators.get_anchor_gt(val_imgs,
                                                 classes_count,
                                                 cfg,
                                                 nn.get_img_output_length,
                                                 K.image_dim_ordering(),
                                                 mode='val')

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)
    # img_input: the three-channel input image
    img_input = Input(shape=input_shape_img)
    # roi_input: the four bounding-box coordinates for the input image
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    # shared_layers: the base network (e.g. resnet, vgg) used to extract the feature map from the original image; these features are then fed into the RPN and RCNN networks
    # 1. define the network input, the convolutional layers shared by Faster R-CNN
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers (2. define the RPN layer)
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    # The RPN generates region proposals: a sigmoid decides whether each anchor belongs to the foreground or the background, and bounding-box regression then refines the anchors into RoIs.
    # rpn: built on the base network, uses 9 anchor boxes to produce the RPN classification and regression outputs; returns [x_class, x_regr, base_layers]
    rpn = nn.rpn(shared_layers, num_anchors)
    # define the classifier layer with its inputs and outputs
    classifier = nn.classifier(shared_layers,
                               roi_input,
                               cfg.num_rois,
                               nb_classes=len(classes_count),
                               trainable=True)
    # define the RPN model's input and outputs: a binary box classification (using sigmoid rather than softmax) and the box regression
    model_rpn = Model(img_input, rpn[:2])
    # define the classifier model's inputs and outputs
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    try:
        print('loading weights from {}'.format(cfg.base_net_weights))
        # TODO: on the first run there was no hdf5 file at model_path, so this was switched to cfg.base_net_weights; it can now be switched back
        model_rpn.load_weights(cfg.model_path, by_name=True)
        model_classifier.load_weights(cfg.model_path, by_name=True)
        # model_rpn.load_weights(cfg.base_net_weights, by_name=True)
        # model_classifier.load_weights(cfg.base_net_weights, by_name=True)
    except Exception as e:
        print(e)
        print(
            'Could not load pretrained model weights. Weights can be found in the keras application folder '
            'https://github.com/fchollet/keras/tree/master/keras/applications')

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses_fn.rpn_loss_cls(num_anchors),
                          losses_fn.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses_fn.class_loss_cls,
            losses_fn.class_loss_regr(len(classes_count) - 1)
        ],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    # TODO: add TensorBoard log files
    log_path = './graph'
    callback = TensorBoard(log_path,
                           histogram_freq=0,
                           write_graph=True,
                           write_images=True)
    callback.set_model(model_all)

    # TODO: the epoch length equals the number of training images
    # epoch_length = len(train_imgs)
    epoch_length = len(val_imgs)
    #epoch_length = 47182
    num_epochs = int(cfg.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    # TODO
    losses_val = np.zeros((epoch_length, 5))

    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    rpn_accuracy_rpn_monitor_val = []
    rpn_accuracy_for_epoch_val = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting testing')

    vis = True

    allbbox = 0

    # originally training only; changed to testing only
    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:

                if len(rpn_accuracy_rpn_monitor
                       ) == epoch_length and cfg.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap'
                            ' the ground truth boxes. Check RPN settings or keep training.'
                        )
                # TODO: changed train to val
                # X, Y, img_data = next(data_gen_train)
                X, Y, img_data = next(data_gen_val)

                # loss_rpn = model_rpn.train_on_batch(X, Y)
                loss_rpn = model_rpn.test_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                result = roi_helpers.rpn_to_roi(P_rpn[0],
                                                P_rpn[1],
                                                cfg,
                                                K.image_dim_ordering(),
                                                use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format

                # TODO: count added to the return values
                # X2, Y1, Y2, IouS = roi_helpers.calc_iou(result, img_data, cfg, class_mapping)
                X2, Y1, Y2, IouS, count = roi_helpers.calc_iou(
                    result, img_data, cfg, class_mapping)
                allbbox = allbbox + count
                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if cfg.num_rois > 1:
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, cfg.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                # save model
                model_all.save_weights(cfg.model_path)
                continue
    print("检测准确率:")
    print(float(allbbox / pedestrain_num))
    print('testing complete, exiting.')
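For context, the ratio printed at the end is simply the accumulated matched-box count over the number of 'Pedestrian' ground-truth boxes in the label file; a tiny illustration with made-up numbers:

# made-up numbers, for illustration only
allbbox = 1532          # matched boxes accumulated from calc_iou over the run
pedestrain_num = 2000   # 'Pedestrian' ground-truth boxes in the label file
print('Detection accuracy: {:.2%}'.format(allbbox / pedestrain_num))  # -> 76.60%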
Example #4
def train_net():
    # config for data augmentation
    cfg = config.Config()

    cfg.use_horizontal_flips = False
    cfg.use_vertical_flips = False
    cfg.rot_90 = False
    cfg.num_rois = 32  # the default in config is 4
    cfg.base_net_weights = os.path.join('./model/', nn.get_weight_path())

    # TODO: this is the only file that should be changed to train on other data
    cfg.model_path = 'samples.hdf5'

    cfg.simple_label_file = 'annotations_train.txt'  # labels generated from the training set

    all_images, classes_count, class_mapping = get_data(cfg.simple_label_file)

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    cfg.class_mapping = class_mapping
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(
            cfg.config_save_file))

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    random.shuffle(all_images)
    num_imgs = len(all_images)
    train_imgs = [s for s in all_images if s['imageset'] == 'trainval']
    val_imgs = [s for s in all_images if s['imageset'] == 'test']

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    # data generators for the images
    data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, cfg, nn.get_img_output_length,
                                                   K.image_dim_ordering(), mode='train')

    data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, cfg, nn.get_img_output_length,
                                                 K.image_dim_ordering(), mode='val')

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)

    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)
    # what is classifier?
    # classes_count {}: number of instances per class, e.g. {'cow': 4, 'dog': 10, ...}
    # C.num_rois: the number of regions of interest sampled each time, 32 by default
    # roi_input = Input(shape=(None, 4)): the boxes
    # classifier holds the two Faster R-CNN outputs [out_class, out_reg]
    # shared_layers is the feature map output by the VGG base
    classifier = nn.classifier(shared_layers, roi_input, cfg.num_rois, nb_classes=len(classes_count), trainable=True)
    # define model_rpn
    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    try:
        print('loading weights from {}'.format(cfg.model_path))
        model_rpn.load_weights(cfg.model_path, by_name=True)
        model_classifier.load_weights(cfg.model_path, by_name=True)
    except Exception as e:
        print(e)
        print('Could not load pretrained model weights. Weights can be found in the keras application folder '
              'https://github.com/fchollet/keras/tree/master/keras/applications')

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[losses_fn.rpn_loss_cls(num_anchors), losses_fn.rpn_loss_regr(num_anchors)])
    model_classifier.compile(optimizer=optimizer_classifier,
                             loss=[losses_fn.class_loss_cls, losses_fn.class_loss_regr(len(classes_count) - 1)],
                             metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 10
    num_epochs = int(cfg.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')

    vis = True

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:
                # monitors the average number of positive boxes per epoch
                if len(rpn_accuracy_rpn_monitor) == epoch_length and cfg.verbose:
                    mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(
                            mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        # if no positive samples are ever produced, something is wrong with the RPN
                        print('RPN is not producing bounding boxes that overlap'
                              ' the ground truth boxes. Check RPN settings or keep training.')

                # pull the next sample from the data generator
                # train the RPN: X is the image, Y holds the anchor class labels and regression targets (only anchors that meet the criteria take part in training)
                # next(data_gen_train) draws from a generator.
                # it returns np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)],
                # and img_data_aug (assuming no horizontal flips etc. were applied, so x_img = img_data_aug);
                # y_rpn_cls and y_rpn_regr feed the two RPN loss functions.
                X, Y, img_data = next(data_gen_train)


                # the classifier and RPN networks are trained alternately
                loss_rpn = model_rpn.train_on_batch(X, Y)
                P_rpn = model_rpn.predict_on_batch(X)

                # result holds the region proposals
                # with the region proposals in hand, the next key idea is RoI pooling,
                # which converts feature maps of different shapes to a fixed shape before the fully connected layers make the final prediction.
                # rpn_to_roi takes the per-image predictions and returns R = [boxes, probs]
                # ---------------------
                result = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], cfg, K.image_dim_ordering(), use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)

                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                # Y1: which class each proposal belongs to,
                # Y2: the regression targets for that class
                # X2: the proposal boxes themselves
                """
                # calc_iou() matches the remaining regions to the ground-truth bbox with the highest overlap, producing the data and labels for model_classifier.
                # X2 keeps all background and matched boxes; Y1 is the one-hot class encoding; Y2 holds the class labels and the regression coordinates to learn; IouS is for debugging.
                """
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(result, img_data, cfg, class_mapping)

                if X2 is None:
                    # if there are no valid proposals, skip this iteration
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                # one-hot encoding: a 1 in the last position marks background
                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]  # flatten to a 1-D array
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if cfg.num_rois > 1:
                    # select num_rois boxes to feed the classifier network for training (how many boxes the classifier trains on at a time)
                    # idea: when num_rois > 1, aim for half positive and half negative samples; when num_rois == 1, randomly pick a single positive or negative sample.
                    if len(pos_samples) < cfg.num_rois // 2:
                        # pick the positive samples
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(pos_samples, cfg.num_rois // 2, replace=False).tolist()
                    try:
                        # pick the negative samples
                        selected_neg_samples = np.random.choice(neg_samples, cfg.num_rois - len(selected_pos_samples),
                                                                replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(neg_samples, cfg.num_rois - len(selected_pos_samples),
                                                                replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                # train the classifier network
                # on the selected RoI positions
                loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]],
                                                             [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]  # rpn_cls loss
                losses[iter_num, 1] = loss_rpn[2]  # rpn_regr loss

                losses[iter_num, 2] = loss_class[1]  # detector_cls loss
                losses[iter_num, 3] = loss_class[2]  # detector_regr loss
                losses[iter_num, 4] = loss_class[3]  # index 4 is the classifier accuracy

                iter_num += 1

                # update the progress bar
                progbar.update(iter_num,
                               [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                ('detector_cls', np.mean(losses[:iter_num, 2])),
                                ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])  # losses holds the values recorded at every iteration of this epoch
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        # print the mean losses over the last epoch_length iterations
                        print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(
                            mean_overlapping_bboxes))
                        print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(loss_class_cls))
                        print('Loss Detector regression: {}'.format(loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() - start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        # at the end of an epoch, the weights are saved only if this epoch's loss is lower than the best seen so far;
                        # the epoch then ends and the next one begins.
                        if cfg.verbose:
                            print('Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                # save model
                model_all.save_weights(cfg.model_path)
                continue
    print('Training complete, exiting.')
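The neg/pos split above relies on the one-hot convention that the last column of Y1 marks background. A standalone sketch with a made-up 4-RoI label tensor:

import numpy as np

# hypothetical Y1: 1 image, 4 RoIs, 3 classes, with the last column being 'bg'
Y1 = np.array([[[1, 0, 0],
                [0, 0, 1],
                [0, 1, 0],
                [0, 0, 1]]])

neg_samples = np.where(Y1[0, :, -1] == 1)[0]  # RoIs labelled as background
pos_samples = np.where(Y1[0, :, -1] == 0)[0]  # RoIs matched to a real class
print(pos_samples)  # [0 2]
print(neg_samples)  # [1 3]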
Example #5
def train_mscoco():
    # =========================== Model configuration and loading ======================================
    # config for data augmentation
    cfg = config.Config()
    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 32
    # weights for the first four convolutional stages of ResNet
    cfg.base_net_weights = nn.get_weight_path()
    # path where the trained model weights are saved
    cfg.model_path = './model/mscoco_frcnn.hdf5'
    #all_images, class_mapping = get_data()
    # load the training images
    train_imgs, class_mapping = get_data('train')

    cfg.class_mapping = class_mapping
    print('Num classes (including bg) = {}'.format(len(class_mapping)))
    # save the full configuration
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(cfg.config_save_file))
    # randomly shuffle the images
    random.shuffle(train_imgs)
    print('Num train samples {}'.format(len(train_imgs)))
    data_gen_train = data_generators.get_anchor_gt(train_imgs,
                                                   class_mapping,
                                                   cfg,
                                                   nn.get_img_output_length,
                                                   K.image_dim_ordering(),
                                                   mode='train')
    # ==============================================================================

    # =============================== Model definition ======================================
    # the Keras backend is TensorFlow
    input_shape_img = (None, None, 3)
    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))
    # define the base resnet50 network
    shared_layers = nn.nn_base(img_input, trainable=False)
    # define the RPN, built on the base layers
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)
    classifier = nn.classifier(shared_layers,
                               roi_input,
                               cfg.num_rois,
                               nb_classes=len(class_mapping),
                               trainable=True)
    # Model(input=..., output=...)
    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)
    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)
    # ==============================================================================

    # =========================== Load ImageNet weights into the base model =============================
    try:
        print('loading base model weights from {}'.format(
            cfg.base_net_weights))
        model_rpn.load_weights(cfg.base_net_weights, by_name=True)
        model_classifier.load_weights(cfg.base_net_weights, by_name=True)
    except Exception as e:
        print('Loading ImageNet weights into the base model: ', e)
        print('Could not load pretrained model weights on ImageNet.')
    # ==============================================================================

    # =============================== Model optimization ========================================
    # create an optimizer object before calling model.compile(), then pass it in
    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses_fn.rpn_loss_cls(num_anchors),
                          losses_fn.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses_fn.class_loss_cls,
            losses_fn.class_loss_regr(len(class_mapping) - 1)
        ],
        metrics={'dense_class_{}'.format(len(class_mapping)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')
    # ==============================================================================

    # ================================ Training and output settings ==================================
    epoch_length = len(train_imgs)
    num_epochs = int(cfg.num_epochs)
    iter_num = 0
    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()
    best_loss = np.Inf

    logger = Logger(os.path.join('.', 'log.txt'))
    # ==============================================================================

    print('Starting training')
    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        logger.write('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:
                if len(rpn_accuracy_rpn_monitor
                       ) == epoch_length and cfg.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap'
                            ' the ground truth boxes. Check RPN settings or keep training.'
                        )
                # image, ground-truth cls/regr targets, and box data
                X, Y, img_data = next(data_gen_train)

                # train the RPN
                loss_rpn = model_rpn.train_on_batch(X, Y)

                # the regions produced while training the RPN are fed to the RoI stage
                #x_class, x_regr, base_layers
                P_rpn = model_rpn.predict_on_batch(X)

                result = roi_helpers.rpn_to_roi(P_rpn[0],
                                                P_rpn[1],
                                                cfg,
                                                K.image_dim_ordering(),
                                                use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                # regions, cls, regr, IoU
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(
                    result, img_data, cfg, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if cfg.num_rois > 1:
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, cfg.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                # train the classifier
                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        logger.write(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        logger.write(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        logger.write(
                            'Loss RPN classifier: {}'.format(loss_rpn_cls))
                        logger.write(
                            'Loss RPN regression: {}'.format(loss_rpn_regr))
                        logger.write('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        logger.write('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        logger.write('Elapsed time: {}'.format(time.time() -
                                                               start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            logger.write(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)
                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                # save model
                model_all.save_weights(cfg.model_path)
                continue
    print('Training complete, exiting.')
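Note that the Logger used in train_mscoco() is imported from elsewhere and not shown in this snippet; a minimal hypothetical stand-in that satisfies the write() calls above could look like this:

class Logger:
    # hypothetical minimal stand-in, assuming only a write() method is required
    def __init__(self, path):
        self.path = path

    def write(self, msg):
        print(msg)
        with open(self.path, 'a') as f:
            f.write(str(msg) + '\n')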
def train_kitti():
    # config for data argument
    cfg = config.Config()
    cfg.balanced_classes = True
    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 50  # for the Xingtubei optical remote-sensing aircraft detection task, this should be changed to 50+
    cfg.anchor_box_scales = [41, 70, 120, 20, 90]
    cfg.anchor_box_ratios = [[1, 1.4], [1, 0.84], [1, 1.17], [1, 0.64], [1, 1]]

    cfg.base_net_weights = os.path.join('./model/', nn.get_weight_path())

    # TODO: this is the only file that should be changed to train on other data
    cfg.model_path = './model/kitti_frcnn_last.hdf5'
    cfg.simple_label_file = 'E:/Xingtubei/official_datas/OpticalAircraft/laptop_Chreoc_OpticalAircraft_bboxes.txt'  # '/media/liuhuaqing/Elements/Xingtubei/official_datas/OpticalAircraft/Chreoc_OpticalAircraft_bboxes.txt'#'F:/Xingtubei/official_datas/OpticalAircraft/Chreoc_OpticalAircraft_bboxes.txt' # 'kitti_simple_label.txt'

    all_images, classes_count, class_mapping = get_data(
        cfg.simple_label_file)  # read the dataset; cv2.imread() requires paths without Chinese characters

    if 'bg' not in classes_count:  # 'bg' stands for background
        classes_count['bg'] = 0  # = 0 means the training data contains no 'background' samples
        class_mapping['bg'] = len(class_mapping)

    cfg.class_mapping = class_mapping
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(cfg.config_save_file))

    inv_map = {v: k for k, v in class_mapping.items()}  # the inverse map of class_mapping

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    random.shuffle(all_images)
    num_imgs = len(all_images)
    train_imgs = [s for s in all_images
                  if s['imageset'] == 'trainval']  # training set: a list whose elements are dicts
    val_imgs = [s for s in all_images
                if s['imageset'] == 'test']  # validation set: a list whose elements are dicts

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = data_generators.get_anchor_gt(
        train_imgs,
        classes_count,
        cfg,
        nn.get_img_output_length,
        K.image_dim_ordering(),
        mode='train')  # data augmentation, then generate the training data Faster R-CNN needs (images, RPN targets, etc.)
    data_gen_val = data_generators.get_anchor_gt(
        val_imgs,
        classes_count,
        cfg,
        nn.get_img_output_length,
        K.image_dim_ordering(),
        mode='val')  # data augmentation, then generate the validation data Faster R-CNN needs (images, RPN targets, etc.)

    # define the input dimensions according to the Keras backend actually in use, since the two backends order dimensions differently
    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)  # Theano backend
    else:
        input_shape_img = (None, None, 3)  # TensorFlow backend

    img_input = Input(shape=input_shape_img)  # input image
    roi_input = Input(shape=(None, 4))  # input for the annotated RoI coordinates; 4 means x1, y1, x2, y2

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(
        img_input,
        trainable=True)  # shared_layers are the layers shared at the bottom of the Faster R-CNN network, here ResNet, as defined in nn

    # define the RPN, built on the base layers
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    classifier = nn.classifier(shared_layers,
                               roi_input,
                               cfg.num_rois,
                               nb_classes=len(classes_count),
                               trainable=True)

    model_rpn = Model(
        img_input,
        rpn[:2])  # the RPN is defined in keras_frcnn/resnet; the first two elements of rpn are the RPN classification and regression outputs
    model_classifier = Model([img_input, roi_input],
                             classifier)  # Keras functional models use Model, i.e. a general model with defined inputs and outputs

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input],
                      rpn[:2] + classifier)  # rpn[:2] + classifier concatenates the two RPN output tensors with the classifier outputs, so model_all exposes all of them

    try:
        # try to load pretrained network weights
        print('loading weights from {}'.format(cfg.model_path))
        model_rpn.load_weights(cfg.model_path, by_name=True)
        model_classifier.load_weights(cfg.model_path, by_name=True)
    except Exception as e:
        print(e)
        print(
            'Could not load pretrained model weights. Weights can be found in the keras application folder '
            'https://github.com/fchollet/keras/tree/master/keras/applications')

    optimizer = Adam(lr=1e-5)  # define an Adam optimizer with learning rate lr
    optimizer_classifier = Adam(lr=1e-5)  # define an Adam optimizer with learning rate lr
    # num_anchors = len(anchor_box_scales) * len(anchor_box_ratios), i.e. 25 with the settings above
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses_fn.rpn_loss_cls(num_anchors),
                          losses_fn.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses_fn.class_loss_cls,
            losses_fn.class_loss_regr(len(classes_count) - 1)
        ],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 100  # every epoch_length iterations, check whether to save the weights, then reset iter_num = 0
    num_epochs = int(cfg.num_epochs)
    iter_num = 0  # initial value of the iteration counter

    losses = np.zeros((epoch_length, 5))  # initialize the loss array recording the losses within each epoch
    rpn_accuracy_rpn_monitor = []  # list tracking RPN accuracy over the course of training
    rpn_accuracy_for_epoch = []  # list tracking RPN accuracy for each training epoch
    start_time = time.time()  # training start time

    best_loss = np.Inf  # records the smallest total loss seen so far

    class_mapping_inv = {v: k
                         for k, v in class_mapping.items()
                         }  # class_mapping_inv is a dict mapping class index to class name
    print('Starting training')

    vis = True

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)  # create a progress bar
        print('Epoch {}/{}'.format(epoch_num + 1,
                                   num_epochs))  # print the current epoch number / total epochs

        while True:  # this loop ends at the break below, i.e. every epoch_length iterations
            try:

                if len(
                        rpn_accuracy_rpn_monitor
                ) == epoch_length and cfg.verbose:  # every epoch_length iterations, display the RPN's mean overlap statistic
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap'
                            ' the ground truth boxes. Check RPN settings or keep training.'
                        )

                # X is the image, e.g. shape (1, 600, 1987, 3) for KITTI; Y is the label; img_data is a dict containing the file name, size, annotated RoIs, classes, etc.
                X, Y, img_data = next(data_gen_train)
                Y_1 = Y[0]
                Y_1 = Y_1[0, :, :, :]

                loss_rpn = model_rpn.train_on_batch(
                    X, Y)  # why does the shape of Y differ from P_rpn? why does loss_rpn have length 3, what do its entries mean, and where is this defined?

                P_rpn = model_rpn.predict_on_batch(
                    X)  # P_rpn shapes are (1, 124, 38, 9) and (1, 124, 38, 36)

                result = roi_helpers.rpn_to_roi(
                    P_rpn[0],
                    P_rpn[1],
                    cfg,
                    K.image_dim_ordering(),
                    use_regr=True,
                    overlap_thresh=0.7,
                    max_boxes=300)  # result has shape 300*4
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                # X2 has shape 100*4, Y1 has shape 1*100*8 (8 = total number of classes in the training set), IouS has length 100
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(
                    result, img_data, cfg, class_mapping
                )  # Y2 has shape 1*1*56, 56 = 28*2 (28 = 4*7); the first 28 entries are coords, the last 28 are labels (1 for the matching class)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(
                    Y1[0, :, -1] == 1)  # Y1 has shape 1*1*8 holding the classification results; a 1 in the last element marks background
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if cfg.num_rois > 1:
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, cfg.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]
                     ])  # feed the RoIs output by the RPN into the classifier

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:  # every epoch_length iterations, check whether to save the weights, then reset iter_num = 0
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                # save model
                model_all.save_weights(cfg.model_path)
                continue
    print('Training complete, exiting.')
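With the custom anchor settings in this last example (5 scales and 5 ratios), the RPN predicts 25 anchors per feature-map position rather than the usual 9; a quick standalone check:

# standalone check of the anchor count implied by the config above
anchor_box_scales = [41, 70, 120, 20, 90]
anchor_box_ratios = [[1, 1.4], [1, 0.84], [1, 1.17], [1, 0.64], [1, 1]]
num_anchors = len(anchor_box_scales) * len(anchor_box_ratios)
print(num_anchors)  # 25 anchors per feature-map location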