Code Example #1
if options.network == 'vgg':
    C.network = 'vgg'
    from keras_frcnn import vgg as nn
elif options.network == 'resnet50':
    from keras_frcnn import resnet as nn
    C.network = 'resnet50'
else:
    print('Not a valid model')
    raise ValueError


# check if weight path was passed via command line
if options.input_weight_path:
    C.base_net_weights = options.input_weight_path
else:
    # set the path to weights based on backend and model
    C.base_net_weights = nn.get_weight_path()

all_imgs, classes_count, class_mapping = get_data(options.train_path)

if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

inv_map = {v: k for k, v in class_mapping.items()}

print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))
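
For reference, a minimal standalone sketch of the 'bg' bookkeeping and inverse class mapping performed above; the class names and counts are made up for illustration:

# Hypothetical class counts and mapping, just to show the bookkeeping.
classes_count = {'car': 120, 'person': 80}
class_mapping = {'car': 0, 'person': 1}

if 'bg' not in classes_count:
    classes_count['bg'] = 0                    # background class starts with zero samples
    class_mapping['bg'] = len(class_mapping)   # 'bg' always gets the last index, here 2

inv_map = {v: k for k, v in class_mapping.items()}
print(inv_map)  # {0: 'car', 1: 'person', 2: 'bg'}
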
Code Example #2
if options.network == 'vgg19':
    from keras_frcnn import vgg19 as nn
    C.network = 'vgg19'
elif options.network == 'mobilenetv1':
    from keras_frcnn import mobilenetv1 as nn
    C.network = 'mobilenetv1'
elif options.network == 'mobilenetv2':
    from keras_frcnn import mobilenetv2 as nn
    C.network = 'mobilenetv2'
else:
    print('Not a valid model')
    raise ValueError

if options.input_weight_path:
    C.base_net_weights = options.input_weight_path
else:
    # set the path to weights based on backend and model
    C.base_net_weights = nn.get_weight_path()

all_imgs, classes_count, class_mapping = get_data(options.train_path)

if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

inv_map = {v: k for k, v in class_mapping.items()}

print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))
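
The if/elif chain above can equally be written as a lookup table. A sketch assuming the same keras_frcnn module names (vgg19, mobilenetv1, mobilenetv2) that appear in the snippet:

import importlib

# --network values mapped to backbone modules (module names assumed to match the snippet above).
SUPPORTED_NETWORKS = {
    'vgg19': 'keras_frcnn.vgg19',
    'mobilenetv1': 'keras_frcnn.mobilenetv1',
    'mobilenetv2': 'keras_frcnn.mobilenetv2',
}

def select_backbone(name, C):
    """Import the backbone module for `name` and record the choice on the config object."""
    if name not in SUPPORTED_NETWORKS:
        raise ValueError('Not a valid model: {}'.format(name))
    nn = importlib.import_module(SUPPORTED_NETWORKS[name])
    C.network = name
    return nn
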
Code Example #3
def work(num_ep, textEdit, traval_path, weights, dataAug_hf, dataAug_vf,
         dataAug_rot, base_network, curvelist):

    from keras_frcnn import losses as losses
    ## usage: python train_frcnn.py -p ./ --input_weight_path ./model_frcnn.hdf5

    sys.setrecursionlimit(40000)

    num_roi = 128
    num_epoch = num_ep
    model_save = "./model_frcnn.hdf5"

    parser = OptionParser()
    parser.add_option("-p",
                      "--path",
                      dest="train_path",
                      help="Path to training data.",
                      default=traval_path)  ###
    parser.add_option("-o",
                      "--parser",
                      dest="parser",
                      help="Parser to use. One of simple or pascal_voc",
                      default="pascal_voc")
    parser.add_option("-n",
                      "--num_rois",
                      type="int",
                      dest="num_rois",
                      help="Number of RoIs to process at once.",
                      default=num_roi)  ###
    parser.add_option("--network",
                      dest="network",
                      help="Base network to use. Supports vgg or resnet50.",
                      default=base_network)  ###resnet50

    parser.add_option(
        "--hf",
        dest="horizontal_flips",
        help="Augment with horizontal flips in training. (Default=false).",
        action="store_true",
        default=dataAug_hf)  ##
    parser.add_option(
        "--vf",
        dest="vertical_flips",
        help="Augment with vertical flips in training. (Default=false).",
        action="store_true",
        default=dataAug_vf)  ##
    parser.add_option(
        "--rot",
        "--rot_90",
        dest="rot_90",
        help="Augment with 90 degree rotations in training. (Default=false).",
        action="store_true",
        default=dataAug_rot)  ##

    parser.add_option("--num_epochs",
                      type="int",
                      dest="num_epochs",
                      help="Number of epochs.",
                      default=num_epoch)  ###2000
    parser.add_option(
        "--config_filename",
        dest="config_filename",
        help=
        "Location to store all the metadata related to the training (to be used when testing).",
        default="config.pickle")

    parser.add_option("--output_weight_path",
                      dest="output_weight_path",
                      help="Output path for weights.",
                      default=model_save)  ###

    parser.add_option(
        "--input_weight_path",
        dest="input_weight_path",
        help=
        "Input path for weights. If not specified, will try to load default weights provided by keras.",
        default=traval_path + weights)

    (options, args) = parser.parse_args()

    if not options.train_path:  # if filename is not given
        parser.error(
            'Error: path to training data must be specified. Pass --path to command line'
        )

    if options.parser == 'pascal_voc':
        from keras_frcnn.pascal_voc_parser import get_data
    elif options.parser == 'simple':
        from keras_frcnn.simple_parser import get_data
    else:
        raise ValueError(
            "Command line option parser must be one of 'pascal_voc' or 'simple'"
        )

    # pass the settings from the command line, and persist them in the config object
    C = config.Config()

    C.use_horizontal_flips = bool(options.horizontal_flips)
    C.use_vertical_flips = bool(options.vertical_flips)
    C.rot_90 = bool(options.rot_90)

    C.model_path = options.output_weight_path
    C.num_rois = int(options.num_rois)

    if options.network == 'vgg':
        C.network = 'vgg'
        from keras_frcnn import vgg as nn
    elif options.network == 'resnet50':
        from keras_frcnn import resnet as nn
        C.network = 'resnet50'
    else:
        print('Not a valid model')
        raise ValueError

    # check if weight path was passed via command line
    if options.input_weight_path:
        C.base_net_weights = options.input_weight_path
    else:
        # set the path to weights based on backend and model
        C.base_net_weights = nn.get_weight_path()

    all_imgs, classes_count, class_mapping = get_data(options.train_path)

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    C.class_mapping = class_mapping

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    textEdit.append('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    textEdit.append('Num classes (including bg) = {}'.format(
        len(classes_count)))
    config_output_filename = options.config_filename

    with open(config_output_filename, 'wb') as config_f:
        pickle.dump(C, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(config_output_filename))
        textEdit.append(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(config_output_filename))
    random.shuffle(all_imgs)  # shuffle the images randomly

    num_imgs = len(all_imgs)

    train_imgs = [s for s in all_imgs if s['imageset'] == 'train']
    val_imgs = [s for s in all_imgs if s['imageset'] == 'val']

    print('Num train samples {}'.format(len(train_imgs)))
    textEdit.append('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))
    textEdit.append('Num val samples {}'.format(len(val_imgs)))

    # RPN computation: build the anchor ground-truth generators for training and validation
    data_gen_train = data_generators.get_anchor_gt(train_imgs,
                                                   classes_count,
                                                   C,
                                                   nn.get_img_output_length,
                                                   K.image_dim_ordering(),
                                                   mode='train')
    data_gen_val = data_generators.get_anchor_gt(val_imgs,
                                                 classes_count,
                                                 C,
                                                 nn.get_img_output_length,
                                                 K.image_dim_ordering(),
                                                 mode='val')

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)

    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    classifier = nn.classifier(shared_layers,
                               roi_input,
                               C.num_rois,
                               nb_classes=len(classes_count),
                               trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    try:
        print('loading weights from {}'.format(C.base_net_weights))
        textEdit.append('loading weights from {}'.format(C.base_net_weights))
        model_rpn.load_weights(C.base_net_weights, by_name=True)
        model_classifier.load_weights(C.base_net_weights, by_name=True)
    except:
        print(
            'Could not load pretrained model weights. Weights can be found in the keras application folder '
            'https://github.com/fchollet/keras/tree/master/keras/applications')

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses.rpn_loss_cls(num_anchors),
                          losses.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses.class_loss_cls,
            losses.class_loss_regr(len(classes_count) - 1)
        ],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')
    print(' training')
    epoch_length = 1
    num_epochs = int(options.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')
    textEdit.append('Starting training')

    vis = True
    save_num = 1

    loss_rpn_cls_list = []
    loss_rpn_regr_list = []
    loss_class_cls_list = []
    loss_class_regr_list = []
    curr_loss_list = []

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))
        textEdit.append('Epoch {}/{}'.format(epoch_num + 1, num_epochs))
        while True:
            try:

                if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.'
                        )

                X, Y, img_data = next(data_gen_train)

                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                R = roi_helpers.rpn_to_roi(P_rpn[0],
                                           P_rpn[1],
                                           C,
                                           K.image_dim_ordering(),
                                           use_regr=True,
                                           overlap_thresh=0.7,
                                           max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(
                    R, img_data, C, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if C.num_rois > 1:
                    if len(pos_samples) < C.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, C.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            C.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            C.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                # the progress-bar data is updated here
                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_cls_list.append(loss_rpn_cls)
                    curvelist[1].setData(loss_rpn_cls_list)
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_rpn_regr_list.append(loss_rpn_regr)
                    curvelist[2].setData(loss_rpn_regr_list)
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_cls_list.append(loss_class_cls)
                    curvelist[3].setData(loss_class_cls_list)
                    loss_class_regr = np.mean(losses[:, 3])
                    loss_class_regr_list.append(loss_class_regr)
                    curvelist[4].setData(loss_class_regr_list)
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if C.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        textEdit.append(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        textEdit.append(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        textEdit.append(
                            'Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        textEdit.append(
                            'Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        textEdit.append('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        textEdit.append('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))
                        textEdit.append('Elapsed time: {}'.format(time.time() -
                                                                  start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    curr_loss_list.append(curr_loss)
                    curvelist[0].setData(curr_loss_list)
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if C.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                            textEdit.append(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(C.model_path)
                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                textEdit.append('Exception: {}'.format(e))
                continue
        if (epoch_num == (save_num * 30)):
            model_all.save_weights(C.model_path + str(save_num * 30) + ".hdf5")
            save_num = save_num + 1
    print('Training complete, exiting.')
    textEdit.append('Training complete, exiting.')
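
work() only touches its GUI arguments through two methods: textEdit.append(str) and curvelist[i].setData(list). A minimal sketch of a call with stub objects standing in for the real widgets; the stub classes, paths and epoch count below are placeholders, not part of the original project:

class TextEditStub:
    """Stands in for the GUI text box; work() only calls .append()."""
    def append(self, msg):
        print(msg)

class CurveStub:
    """Stands in for a plot curve; work() only calls .setData()."""
    def setData(self, values):
        pass  # a real curve widget would redraw here

work(num_ep=10,
     textEdit=TextEditStub(),
     traval_path='./data/',
     weights='model_frcnn.hdf5',
     dataAug_hf=True,
     dataAug_vf=False,
     dataAug_rot=False,
     base_network='resnet50',
     curvelist=[CurveStub() for _ in range(5)])  # indices 0-4 are used
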
Code Example #4
def train_kitti():
    # config for data argument
    cfg = config.Config()

    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 32
    cfg.base_net_weights = os.path.join('./model/', nn.get_weight_path())
    # cfg.base_net_weights=r''

    # TODO: the only file should to be change for other data to train
    cfg.model_path = '/media/private/Ci/log/plane/frcnn/vgg-adam'

    now = datetime.datetime.now()
    day = now.strftime('%y-%m-%d')
    for i in range(10000):
        if not os.path.exists('%s-%s-%d' % (cfg.model_path, day, i)):
            cfg.model_path = '%s-%s-%d' % (cfg.model_path, day, i)
            break

    make_dir(cfg.model_path)
    make_dir(cfg.model_path + '/loss')
    make_dir(cfg.model_path + '/loss_rpn_cls')
    make_dir(cfg.model_path + '/loss_rpn_regr')
    make_dir(cfg.model_path + '/loss_class_cls')
    make_dir(cfg.model_path + '/loss_class_regr')

    cfg.simple_label_file = '/media/public/GEOWAY/plane/plane0817.csv'

    all_images, classes_count, class_mapping = get_data(cfg.simple_label_file)

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    cfg.class_mapping = class_mapping
    cfg.config_save_file = os.path.join(cfg.model_path, 'config.pickle')
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(cfg.config_save_file))

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    random.shuffle(all_images)
    num_imgs = len(all_images)
    train_imgs = [s for s in all_images if s['imageset'] == 'trainval']
    val_imgs = [s for s in all_images if s['imageset'] == 'test']

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = data_generators.get_anchor_gt(train_imgs,
                                                   classes_count,
                                                   cfg,
                                                   nn.get_img_output_length,
                                                   K.image_dim_ordering(),
                                                   mode='train')
    data_gen_val = data_generators.get_anchor_gt(val_imgs,
                                                 classes_count,
                                                 cfg,
                                                 nn.get_img_output_length,
                                                 K.image_dim_ordering(),
                                                 mode='val')
    Q = multiprocessing.Manager().Queue(maxsize=30)

    def fill_Q(n):
        while True:

            if not Q.full():
                Q.put(next(data_gen_train))
                #print(Q.qsize(),'put',n)
            else:
                time.sleep(0.00001)

    threads = []
    for i in range(4):
        thread = multiprocessing.Process(target=fill_Q, args=(i, ))
        threads.append(thread)
        thread.start()

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)

    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    classifier = nn.classifier(shared_layers,
                               roi_input,
                               cfg.num_rois,
                               nb_classes=len(classes_count),
                               trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)
    # model_all.summary()
    from keras.utils import plot_model
    # os.environ['PATH'] = os.environ['PATH'] + r';C:\Program Files (x86)\Graphviz2.38\bin;'

    plot_model(model_all,
               'model_all.png',
               show_layer_names=True,
               show_shapes=True)
    plot_model(model_classifier,
               'model_classifier.png',
               show_layer_names=True,
               show_shapes=True)
    plot_model(model_rpn,
               'model_rpn.png',
               show_layer_names=True,
               show_shapes=True)
    '''
    try:
        print('loading weights from {}'.format(cfg.base_net_weights))
        model_rpn.load_weights(cfg.model_path, by_name=True)
        model_classifier.load_weights(cfg.model_path, by_name=True)
    except Exception as e:
        print(e)
        print('Could not load pretrained model weights. Weights can be found in the keras application folder '
              'https://github.com/fchollet/keras/tree/master/keras/applications')
    '''

    optimizer = adadelta()
    optimizer_classifier = adadelta()
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses_fn.rpn_loss_cls(num_anchors),
                          losses_fn.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses_fn.class_loss_cls,
            losses_fn.class_loss_regr(len(classes_count) - 1)
        ],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 10
    num_epochs = int(cfg.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf
    best_rpn_cls = np.Inf
    best_rpn_regr = np.Inf
    best_class_cls = np.Inf
    best_class_regr = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')

    vis = True

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:

                if len(rpn_accuracy_rpn_monitor
                       ) == epoch_length and cfg.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap'
                            ' the ground truth boxes. Check RPN settings or keep training.'
                        )

            #    X, Y, img_data = next(data_gen_train)
                while True:

                    if Q.empty():
                        time.sleep(0.00001)
                        continue

                    X, Y, img_data = Q.get()
                    #    print(Q.qsize(),'get')
                    break
            #  print(X.shape,Y.shape)
                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                result = roi_helpers.rpn_to_roi(P_rpn[0],
                                                P_rpn[1],
                                                cfg,
                                                K.image_dim_ordering(),
                                                use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(
                    result, img_data, cfg, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if cfg.num_rois > 1:
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, cfg.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(
                            '%s/%s/E-%d-loss-%.4f-rpnc-%.4f-rpnr-%.4f-cls-%.4f-cr-%.4f.hdf5'
                            % (cfg.model_path, 'loss', epoch_num, curr_loss,
                               loss_rpn_cls, loss_rpn_regr, loss_class_cls,
                               loss_class_regr))
                    if loss_rpn_cls < best_rpn_cls:
                        if cfg.verbose:
                            print(
                                'loss_rpn_cls decreased from {} to {}, saving weights'
                                .format(best_rpn_cls, loss_rpn_cls))
                        best_rpn_cls = loss_rpn_cls
                        model_all.save_weights(
                            '%s/%s/E-%d-loss-%.4f-rpnc-%.4f-rpnr-%.4f-cls-%.4f-cr-%.4f.hdf5'
                            % (cfg.model_path, 'loss_rpn_cls', epoch_num,
                               curr_loss, loss_rpn_cls, loss_rpn_regr,
                               loss_class_cls, loss_class_regr))
                    if loss_rpn_regr < best_rpn_regr:
                        if cfg.verbose:
                            print(
                                'loss_rpn_regr decreased from {} to {}, saving weights'
                                .format(best_rpn_regr, loss_rpn_regr))
                        best_rpn_regr = loss_rpn_regr
                        model_all.save_weights(
                            '%s/%s/E-%d-loss-%.4f-rpnc-%.4f-rpnr-%.4f-cls-%.4f-cr-%.4f.hdf5'
                            % (cfg.model_path, 'loss_rpn_regr', epoch_num,
                               curr_loss, loss_rpn_cls, loss_rpn_regr,
                               loss_class_cls, loss_class_regr))
                    if loss_class_cls < best_class_cls:
                        if cfg.verbose:
                            print(
                                'loss_class_cls decreased from {} to {}, saving weights'
                                .format(best_class_cls, loss_class_cls))
                        best_class_cls = loss_class_cls
                        model_all.save_weights(
                            '%s/%s/E-%d-loss-%.4f-rpnc-%.4f-rpnr-%.4f-cls-%.4f-cr-%.4f.hdf5'
                            % (cfg.model_path, 'loss_class_cls', epoch_num,
                               curr_loss, loss_rpn_cls, loss_rpn_regr,
                               loss_class_cls, loss_class_regr))
                    if loss_class_regr < best_class_regr:
                        if cfg.verbose:
                            print(
                                'loss_class_regr decreased from {} to {}, saving weights'
                                .format(best_class_regr, loss_class_regr))
                        best_class_regr = loss_class_regr
                        model_all.save_weights(
                            '%s/%s/E-%d-loss-%.4f-rpnc-%.4f-rpnr-%.4f-cls-%.4f-cr-%.4f.hdf5'
                            % (cfg.model_path, 'loss_class_regr', epoch_num,
                               curr_loss, loss_rpn_cls, loss_rpn_regr,
                               loss_class_cls, loss_class_regr))

                    break

            except Exception as e:
                #   print('Exception: {}'.format(e))
                # save model
                #    model_all.save_weights(cfg.model_path)
                continue
    print('Training complete, exiting.')
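
The batch feeding in this example runs several producer processes that fill a Manager queue while the training loop drains it. The same pattern in isolation, assuming a fork start method (the producers close over the generator, exactly as in the code above); the dummy generator and step count are illustrative:

import multiprocessing
import time

def batch_source():
    # Stand-in for data_gen_train: yields dummy "batches".
    i = 0
    while True:
        yield i
        i += 1

if __name__ == '__main__':
    Q = multiprocessing.Manager().Queue(maxsize=30)
    gen = batch_source()

    def fill_Q(worker_id):
        # Producer: push the next batch when there is room, otherwise busy-wait briefly.
        # Note: with fork, each producer works on its own copy of the generator,
        # as in the snippet above.
        while True:
            if not Q.full():
                Q.put(next(gen))
            else:
                time.sleep(1e-5)

    producers = [multiprocessing.Process(target=fill_Q, args=(i,), daemon=True)
                 for i in range(4)]
    for p in producers:
        p.start()

    # Consumer: mirrors the inner while-loop of the training code.
    for _ in range(5):
        while Q.empty():
            time.sleep(1e-5)
        print('got batch', Q.get())
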
Code Example #5
def run_train(train_path,
              output_weight_path,
              config_filename,
              parser='simple',
              input_weight_path=None,
              network='resnet50',
              num_rois=32,
              lr=1e-5,
              iters=100,
              num_epochs=100,
              overlap_th=0.7):

    C = config.Config()

    C.model_path = output_weight_path
    C.num_rois = int(num_rois)

    if network == 'vgg':
        C.network = 'vgg'
        from keras_frcnn import vgg as nn
    elif network == 'resnet50':
        from keras_frcnn import resnet as nn
        C.network = 'resnet50'
    else:
        print('Not a valid model')
        raise ValueError

    if parser == 'pascal_voc':
        from keras_frcnn.pascal_voc_parser import get_data
    elif parser == 'simple':
        from keras_frcnn.simple_parser import get_data
    else:
        print('Wrong parser method')
        raise ValueError

    if input_weight_path is not None:
        C.base_net_weights = input_weight_path
    else:
        C.base_net_weights = nn.get_weight_path()

    all_imgs, classes_count, class_mapping = get_data(train_path)

    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    C.class_mapping = class_mapping

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))

    config_output_filename = config_filename

    with open(config_output_filename, 'wb') as config_f:
        pickle.dump(C, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(config_output_filename))

    random.shuffle(all_imgs)

    num_imgs = len(all_imgs)

    train_imgs = [s for s in all_imgs if s['imageset'] == 'trainval']
    val_imgs = [s for s in all_imgs if s['imageset'] == 'test']

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = data_generators.get_anchor_gt(train_imgs,
                                                   classes_count,
                                                   C,
                                                   nn.get_img_output_length,
                                                   K.image_dim_ordering(),
                                                   mode='train')
    data_gen_val = data_generators.get_anchor_gt(val_imgs,
                                                 classes_count,
                                                 C,
                                                 nn.get_img_output_length,
                                                 K.image_dim_ordering(),
                                                 mode='val')

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)

    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    classifier = nn.classifier(shared_layers,
                               roi_input,
                               C.num_rois,
                               nb_classes=len(classes_count),
                               trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    try:
        print('loading weights from {}'.format(C.base_net_weights))
        model_rpn.load_weights(C.base_net_weights, by_name=True)
        model_classifier.load_weights(C.base_net_weights, by_name=True)
    except:
        print('Could not load pretrained model weights...')

    optimizer = Adam(lr=lr)
    optimizer_classifier = Adam(lr=lr)
    from keras_frcnn import losses as losses
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses.rpn_loss_cls(num_anchors),
                          losses.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses.class_loss_cls,
            losses.class_loss_regr(len(classes_count) - 1)
        ],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = int(iters)
    num_epochs = int(num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')

    vis = True

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:
                if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.'
                        )

                X, Y, img_data = next(data_gen_train)

                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

                R = roi_helpers.rpn_to_roi(P_rpn[0],
                                           P_rpn[1],
                                           C,
                                           K.image_dim_ordering(),
                                           use_regr=True,
                                           overlap_thresh=overlap_th,
                                           max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(
                    R, img_data, C, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                neg_samples = np.where(Y1[0, :, -1] == 1)
                pos_samples = np.where(Y1[0, :, -1] == 0)

                if len(neg_samples) > 0:
                    neg_samples = neg_samples[0]
                else:
                    neg_samples = []

                if len(pos_samples) > 0:
                    pos_samples = pos_samples[0]
                else:
                    pos_samples = []

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append((len(pos_samples)))

                if C.num_rois > 1:
                    if len(pos_samples) < C.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, C.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            C.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            C.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if C.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if C.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(C.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                continue

    print('Training complete, exiting.')
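
A hypothetical call to run_train(), followed by reloading the pickled config the way a test script would; all file names here are placeholders:

import pickle

# Train from a simple-parser annotation file (paths are illustrative).
run_train(train_path='./annotations.txt',
          output_weight_path='./model_frcnn.hdf5',
          config_filename='./config.pickle',
          parser='simple',
          network='resnet50',
          num_rois=32,
          lr=1e-5,
          iters=100,
          num_epochs=50)

# At test time the same Config object is restored so that preprocessing and the
# class mapping match what was used during training.
with open('./config.pickle', 'rb') as f:
    C = pickle.load(f)
print(C.class_mapping)
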
Code Example #6
if options.network == 'vgg':
    C.network = 'vgg'
    from keras_frcnn import vgg as nn
elif options.network == 'resnet50':
    from keras_frcnn import resnet as nn
    C.network = 'resnet50'
else:
    print('Not a valid model')
    raise ValueError

# check if weight path was passed via command line
if options.input_weight_path:
    C.base_net_weights = options.input_weight_path
else:
    # set the path to weights based on backend and model
    C.base_net_weights = nn.get_weight_path()  # e.g. 'resnet50_weights_th_dim_ordering_th_kernels_notop.h5'

all_imgs, classes_count, class_mapping = get_data(options.train_path)

if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

inv_map = {v: k for k, v in class_mapping.items()}

print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))
Code Example #7
def train_kitti():
    # config for data argument
    cfg = config.Config()

    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 32
    cfg.base_net_weights = os.path.join('./model/', nn.get_weight_path())

    # TODO: the only file should to be change for other data to train
    cfg.model_path = './model/kitti_frcnn_last.hdf5'
    cfg.simple_label_file = 'kitti_simple_label.txt'

    all_images, classes_count, class_mapping = get_data(cfg.simple_label_file)

    # all_images: a list of dicts, one per image, recording the file path, the image size,
    # every bounding-box position, and whether the image belongs to the trainval or test split:
    # {'filepath': './Datasets/training/image_2/000001.png', 'width': 1242, 'height': 375,
    #  'bboxes': [{'class': 'Truck', 'x1': 599, 'x2': 629, 'y1': 156, 'y2': 189},
    #             {'class': 'Car', 'x1': 387, 'x2': 423, 'y1': 181, 'y2': 203},
    #             {'class': 'Cyclist', 'x1': 676, 'x2': 688, 'y1': 163, 'y2': 193},
    #             {'class': 'DontCare', 'x1': 503, 'x2': 590, 'y1': 169, 'y2': 190},
    #             {'class': 'DontCare', 'x1': 511, 'x2': 527, 'y1': 174, 'y2': 187},
    #             {'class': 'DontCare', 'x1': 532, 'x2': 542, 'y1': 176, 'y2': 185},
    #             {'class': 'DontCare', 'x1': 559, 'x2': 575, 'y1': 175, 'y2': 183}], 'imageset': 'trainval'}

    # classes_count: a dict of per-class sample counts, e.g.:
    # {'Pedestrian': 3, 'Truck': 1, 'Car': 29, 'Cyclist': 2, 'DontCare': 29, 'Misc': 1}
    # class_mapping: a dict mapping class names to indices, e.g.:
    # {'Pedestrian': 0, 'Truck': 1, 'Car': 2, 'Cyclist': 3, 'DontCare': 4, 'Misc': 5}

    # 'bg' is the background class; hard negative mining sometimes adds background
    # samples that contain no objects at all, so the class is registered here.
    if 'bg' not in classes_count:
        classes_count['bg'] = 0
        class_mapping['bg'] = len(class_mapping)

    cfg.class_mapping = class_mapping
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
        print(
            'Config has been written to {}, and can be loaded when testing to ensure correct results'
            .format(cfg.config_save_file))

    inv_map = {v: k for k, v in class_mapping.items()}

    print('Training images per class:')
    pprint.pprint(classes_count)
    print('Num classes (including bg) = {}'.format(len(classes_count)))
    random.shuffle(all_images)
    num_imgs = len(all_images)
    train_imgs = [s for s in all_images if s['imageset'] == 'trainval']
    val_imgs = [s for s in all_images if s['imageset'] == 'test']

    print('Num train samples {}'.format(len(train_imgs)))
    print('Num val samples {}'.format(len(val_imgs)))

    data_gen_train = data_generators.get_anchor_gt(train_imgs,
                                                   classes_count,
                                                   cfg,
                                                   nn.get_img_output_length,
                                                   K.image_dim_ordering(),
                                                   mode='train')
    data_gen_val = data_generators.get_anchor_gt(val_imgs,
                                                 classes_count,
                                                 cfg,
                                                 nn.get_img_output_length,
                                                 K.image_dim_ordering(),
                                                 mode='val')

    if K.image_dim_ordering() == 'th':
        input_shape_img = (3, None, None)
    else:
        input_shape_img = (None, None, 3)

    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # define the base network (resnet here, can be VGG, Inception, etc)
    shared_layers = nn.nn_base(img_input, trainable=True)

    # define the RPN, built on the base layers
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)

    classifier = nn.classifier(shared_layers,
                               roi_input,
                               cfg.num_rois,
                               nb_classes=len(classes_count),
                               trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)

    # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    try:
        print('loading weights from {}'.format(cfg.base_net_weights))
        model_rpn.load_weights(cfg.model_path, by_name=True)
        model_classifier.load_weights(cfg.model_path, by_name=True)
    except Exception as e:
        print(e)
        print(
            'Could not load pretrained model weights. Weights can be found in the keras application folder '
            'https://github.com/fchollet/keras/tree/master/keras/applications')

    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses_fn.rpn_loss_cls(num_anchors),
                          losses_fn.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses_fn.class_loss_cls,
            losses_fn.class_loss_regr(len(classes_count) - 1)
        ],
        metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
    model_all.compile(optimizer='sgd', loss='mae')

    epoch_length = 1000
    num_epochs = int(cfg.num_epochs)
    iter_num = 0

    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    best_loss = np.Inf

    class_mapping_inv = {v: k for k, v in class_mapping.items()}
    print('Starting training')

    vis = True

    rpncallback = TensorBoard(log_dir='./logs/rpn',
                              histogram_freq=1,
                              batch_size=1,
                              write_graph=True,
                              write_grads=False,
                              write_images=False,
                              embeddings_freq=0,
                              embeddings_layer_names=None,
                              embeddings_metadata=None,
                              embeddings_data=None,
                              update_freq='epoch')
    classifiercallback = TensorBoard(log_dir='./logs/classifier',
                                     histogram_freq=1,
                                     batch_size=1,
                                     write_graph=True,
                                     write_grads=False,
                                     write_images=False,
                                     embeddings_freq=0,
                                     embeddings_layer_names=None,
                                     embeddings_metadata=None,
                                     embeddings_data=None,
                                     update_freq='epoch')
    allcallback = TensorBoard(log_dir='./logs/all',
                              histogram_freq=1,
                              batch_size=1,
                              write_graph=True,
                              write_grads=False,
                              write_images=False,
                              embeddings_freq=0,
                              embeddings_layer_names=None,
                              embeddings_metadata=None,
                              embeddings_data=None,
                              update_freq='epoch')

    for epoch_num in range(num_epochs):

        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:

                if len(rpn_accuracy_rpn_monitor) == epoch_length and cfg.verbose:
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print('Average number of overlapping bounding boxes from RPN = {} '
                          'for {} previous iterations'.format(mean_overlapping_bboxes,
                                                              epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print('RPN is not producing bounding boxes that overlap'
                              ' the ground truth boxes. Check RPN settings or keep training.')

                X, Y, img_data = next(data_gen_train)

                loss_rpn = model_rpn.train_on_batch(X, Y)

                P_rpn = model_rpn.predict_on_batch(X)

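                # turn the raw RPN outputs (objectness scores + box regressions)
                # into ROI proposals; non-max suppression keeps at most 300 boxes
                # with an overlap threshold of 0.7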
                result = roi_helpers.rpn_to_roi(P_rpn[0],
                                                P_rpn[1],
                                                cfg,
                                                K.image_dim_ordering(),
                                                use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)
                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                X2, Y1, Y2, IouS = roi_helpers.calc_iou(
                    result, img_data, cfg, class_mapping)

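                # calc_iou returns None when no proposal could be matched to a
                # ground-truth box, so record zero positives and move on to the
                # next image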
                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                # np.where returns a tuple of index arrays, so take the first
                # element directly; negatives are ROIs labelled as background
                neg_samples = np.where(Y1[0, :, -1] == 1)[0]
                pos_samples = np.where(Y1[0, :, -1] == 0)[0]

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append(len(pos_samples))

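                # build the detector mini-batch of cfg.num_rois ROIs, aiming for
                # a roughly 50/50 split of positive and negative samples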
                if cfg.num_rois > 1:
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, cfg.num_rois // 2,
                            replace=False).tolist()
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=False).tolist()
                    except ValueError:
                        # not enough negative samples to fill the mini-batch,
                        # so sample with replacement instead
                        selected_neg_samples = np.random.choice(
                            neg_samples,
                            cfg.num_rois - len(selected_pos_samples),
                            replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, pick a single random
                    # pos or neg sample; keep it in a list so the ROI axis is
                    # preserved when slicing X2/Y1/Y2 below
                    selected_pos_samples = pos_samples.tolist()
                    selected_neg_samples = neg_samples.tolist()
                    if np.random.randint(0, 2):
                        sel_samples = [random.choice(selected_neg_samples)]
                    else:
                        sel_samples = [random.choice(selected_pos_samples)]

                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

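                # end of an epoch (epoch_length iterations): report mean losses,
                # reset the counters and save the combined weights whenever the
                # total loss has improved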
                if iter_num == epoch_length:
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(sum(
                        rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    if cfg.verbose:
                        print(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        print(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                        print('Loss RPN regression: {}'.format(loss_rpn_regr))
                        print('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        print('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        print('Elapsed time: {}'.format(time.time() -
                                                        start_time))

                    curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            print(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)

                    break

            except Exception as e:
                print('Exception: {}'.format(e))
                # save the current weights so progress is not lost, then skip
                # to the next image
                model_all.save_weights(cfg.model_path)
                continue
    print('Training complete, exiting.')