Exemple #1
0
# Define the RPN (region proposal network), built on the shared base layers,
# followed by the detection classifier head and the fine-grained classifier.
# NOTE(review): `cfg`, `nn`, `shared_layers`, `roi_input`, `part_roi_input`
# and `classes_count` are not bound anywhere above this point in the visible
# code; they must be defined earlier in the real file or this raises
# NameError -- confirm.
# NOTE(review): these calls presumably create Keras layers. Keras derives
# default layer names from creation order, which matters if weights are later
# loaded by layer name, so keep the statement order unchanged.

# Anchor count = (number of anchor scales) * (number of anchor ratios).
num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
# Original author's note: this presumably only applies a couple of plain
# convolution layers; anchors are not introduced here and should only come
# into play in the loss function (unverified from this file).
rpn = nn.rpn(shared_layers,
             num_anchors)
# Original author's note: the return value is a list holding the RPN
# classification and regression outputs (not verifiable from this file).

# Detection classifier head over the ROIs fed through `roi_input`, with
# cfg.num_rois passed through to nn.classifier.
# NOTE: be careful with nb_classes here when modifying the program -- it is
# derived from classes_count and must match the target class set.
classifier = nn.classifier(shared_layers,
                           roi_input,
                           cfg.num_rois,
                           nb_classes=len(classes_count),
                           trainable=True)
# roi_input appears to be used as a model input here; check how it is set up
# further down (its only visible definition is commented out).
# Disabled alternative: a per-part bird classifier taking seven ROI inputs.
#bird_classifier_output = nn.fg_classifier(shared_layers,bird_rois_input0,bird_rois_input1,bird_rois_input2,bird_rois_input3,bird_rois_input4,bird_rois_input5,bird_rois_input6,nb_classes=200, trainable=True)
# Fine-grained ("holy", presumably "whole"-image) classifier over the part ROIs.
# NOTE(review): nb_classes is hard-coded to 200 rather than derived from
# classes_count -- confirm it matches the dataset in use.
holyclass_out = nn.fine_layer(shared_layers,
                              part_roi_input,
                              nb_classes=200,
                              trainable=True)
'''
head_roi = Input(shape=(None, 4))
legs_roi  = Input(shape=(None, 4))
wings_roi = Input(shape=(None, 4))
back_roi = Input(shape=(None, 4))
belly_roi = Input(shape=(None, 4))
breast_roi =  Input(shape=(None, 4))
tail_roi = Input(shape=(None, 4))
bird_roi_input = [head_roi,legs_roi,wings_roi,back_roi,belly_roi,breast_roi,tail_roi]

head_classifier = nn.fg_classifier(shared_layers,head_roi , cfg.num_rois,nb_classes=10)
legs_classifier = nn.fg_classifier(shared_layers, legs_roi, cfg.num_rois,nb_classes=10)
wings_classifier = nn.fg_classifier(shared_layers, wings_roi, cfg.num_rois,nb_classes=10)
back_classifier = nn.fg_classifier(shared_layers, back_roi, cfg.num_rois,nb_classes=10)
val_imgs = [s for s in all_imgs if s['imageset'] == 'test']
print('Num train samples {}'.format(len(train_imgs)))
print('Num val samples {}'.format(len(val_imgs)))

#data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, cfg, nn.get_img_output_length,K.image_dim_ordering(), mode='train')
#data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, cfg, nn.get_img_output_length,K.image_dim_ordering(), mode='val')
input_shape_img = (None, None, 3)

img_input = Input(shape=input_shape_img)
#roi_input = Input(shape=(None, 4))  # roiinput是什么,要去看看清楚
part_roi_input = Input(shape=[None, 4])

# define the base network (resnet here, can be VGG, Inception, etc)
shared_layers = nn.nn_base(img_input, trainable=True)  # 共享网络层的输出.要明确输出的size
#bird_classifier_output = nn.fg_classifier(shared_layers,bird_rois_input0,bird_rois_input1,bird_rois_input2,bird_rois_input3,bird_rois_input4,bird_rois_input5,bird_rois_input6,nb_classes=200, trainable=True)
holyclass_out = nn.fine_layer(shared_layers, part_roi_input, nb_classes=200)

#class_holyimg_out = nn.fine_layer_hole(shared_layers, part_roi_input,num_rois=1,nb_classes=200)

model_holyclassifier = Model([img_input, part_roi_input], holyclass_out)
#model_classifier_holyimg = Model([img_input,part_roi_input],class_holyimg_out)
cfg.base_net_weights = '/media/e813/D/weights/kerash5/frcnn/TST_holy_img/model_part{}.hdf5'.format(
    8)
cfg.base_net_weights = cfg.weigth_to_save_load(16)
cfg.base_net_weights = '/media/e813/D/weights/kerash5/cac/xunlian_ori/model_part27.hdf5'
try:
    print('loading weights from {}'.format(cfg.base_net_weights))
    #model_rpn.load_weights(cfg.base_net_weights, by_name=True)
    #model_classifier.load_weights(cfg.base_net_weights, by_name=True)
    #model_birdclassifier.load_weights(cfg.base_net_weights, by_name=True)
    model_holyclassifier.load_weights(cfg.base_net_weights, by_name=True)