예제 #1
0
 def bulid_bcnn_1(self):
     """Build ``self.model`` with one pretrained backbone shared by both streams.

     The backbone selected by ``self.base_model`` is created without its top
     layers and with ``weights='imagenet'``. The same network object is stored
     as ``basenet1``/``basenet2`` and as ``model_detector``/``model_extractor``.
     ``self.output`` is whatever ``self.build_bilinear_cnn`` returns, and
     ``self.model`` maps the input tensor to it.

     The misspelled method name is kept so existing callers keep working.

     Raises:
         ValueError: If ``self.base_model`` is not ``'inception_v3'``,
             ``'resnet_50'`` or ``'inception_resnet_v2'``. Previously such a
             value fell through every branch and ``Model`` was built from an
             unset (or stale) ``self.output``.
     """
     # Resolve the backbone before creating any tensors so that an
     # unsupported name fails fast with a clear message.
     if self.base_model == 'inception_v3':
         backbone_cls = inception_v3.InceptionV3
         # Only the InceptionV3 constructor was originally given `classes`.
         extra_kwargs = {'classes': self.num_classes}
         last_conv_layer = -1
     elif self.base_model == 'resnet_50':
         backbone_cls = resnet_50.ResNet50
         extra_kwargs = {}
         last_conv_layer = -2
     elif self.base_model == 'inception_resnet_v2':
         backbone_cls = inception_resnet_v2.InceptionResNetV2
         extra_kwargs = {}
         last_conv_layer = -1
     else:
         raise ValueError("Unsupported base_model: {!r}".format(self.base_model))

     input_tensor = Input(shape=self.input_shape)
     self.basenet1 = backbone_cls(input_shape=self.input_shape, input_tensor=input_tensor,
                                  include_top=False, weights='imagenet', **extra_kwargs)
     # Both streams intentionally alias the same network object.
     self.basenet2 = self.basenet1
     self.model_detector = self.basenet1
     self.model_extractor = self.basenet2
     # NOTE(review): last_conv_layer looks like a layer index into the
     # backbone -- confirm against build_bilinear_cnn.
     self.output = self.build_bilinear_cnn(last_conv_layer=last_conv_layer)
     self.model = Model(input_tensor, self.output)
예제 #2
0
def get_model():
    """Instantiate the network selected by the module-level ``model_index``.

    Returns:
        The object produced by the constructor/factory registered for
        ``model_index`` (supported indices: 0-16).

    Raises:
        ValueError: If ``model_index`` does not match a supported index.
            Previously the function fell off the end and returned ``None``,
            which only surfaced later as an unrelated error.
    """
    # Lambdas keep the dispatch lazy: only the selected entry's module
    # attribute is looked up and only that network is constructed.
    builders = {
        0: lambda: mobilenet_v1.MobileNetV1(),
        1: lambda: mobilenet_v2.MobileNetV2(),
        2: lambda: mobilenet_v3_large.MobileNetV3Large(),
        3: lambda: mobilenet_v3_small.MobileNetV3Small(),
        4: lambda: efficientnet.efficient_net_b0(),
        5: lambda: efficientnet.efficient_net_b1(),
        6: lambda: efficientnet.efficient_net_b2(),
        7: lambda: efficientnet.efficient_net_b3(),
        8: lambda: efficientnet.efficient_net_b4(),
        9: lambda: efficientnet.efficient_net_b5(),
        10: lambda: efficientnet.efficient_net_b6(),
        11: lambda: efficientnet.efficient_net_b7(),
        12: lambda: resnext.ResNeXt50(),
        13: lambda: resnext.ResNeXt101(),
        14: lambda: inception_v4.InceptionV4(),
        15: lambda: inception_resnet_v1.InceptionResNetV1(),
        16: lambda: inception_resnet_v2.InceptionResNetV2(),
    }
    try:
        builder = builders.get(model_index)
    except TypeError:
        # An unhashable model_index can never equal a supported index.
        builder = None
    if builder is None:
        raise ValueError("The model_index does not exist.")
    return builder()
예제 #3
0
def get_model():
    """Instantiate the network selected by the module-level ``model_index``.

    Returns:
        The object produced by the constructor/factory registered for
        ``model_index`` (supported indices: 0-28).

    Raises:
        ValueError: If ``model_index`` does not match a supported index.
            Previously the function fell off the end and returned ``None``,
            which only surfaced later as an unrelated error.
    """
    # Lambdas keep the dispatch lazy: only the selected entry's module
    # attribute is looked up and only that network is constructed.
    builders = {
        0: lambda: mobilenet_v1.MobileNetV1(),
        1: lambda: mobilenet_v2.MobileNetV2(),
        2: lambda: mobilenet_v3_large.MobileNetV3Large(),
        3: lambda: mobilenet_v3_small.MobileNetV3Small(),
        4: lambda: efficientnet.efficient_net_b0(),
        5: lambda: efficientnet.efficient_net_b1(),
        6: lambda: efficientnet.efficient_net_b2(),
        7: lambda: efficientnet.efficient_net_b3(),
        8: lambda: efficientnet.efficient_net_b4(),
        9: lambda: efficientnet.efficient_net_b5(),
        10: lambda: efficientnet.efficient_net_b6(),
        11: lambda: efficientnet.efficient_net_b7(),
        12: lambda: resnext.ResNeXt50(),
        13: lambda: resnext.ResNeXt101(),
        14: lambda: inception_v4.InceptionV4(),
        15: lambda: inception_resnet_v1.InceptionResNetV1(),
        16: lambda: inception_resnet_v2.InceptionResNetV2(),
        17: lambda: se_resnet.se_resnet_50(),
        18: lambda: se_resnet.se_resnet_101(),
        19: lambda: se_resnet.se_resnet_152(),
        20: lambda: squeezenet.SqueezeNet(),
        21: lambda: densenet.densenet_121(),
        22: lambda: densenet.densenet_169(),
        23: lambda: densenet.densenet_201(),
        24: lambda: densenet.densenet_264(),
        25: lambda: shufflenet_v2.shufflenet_0_5x(),
        26: lambda: shufflenet_v2.shufflenet_1_0x(),
        27: lambda: shufflenet_v2.shufflenet_1_5x(),
        28: lambda: shufflenet_v2.shufflenet_2_0x(),
    }
    try:
        builder = builders.get(model_index)
    except TypeError:
        # An unhashable model_index can never equal a supported index.
        builder = None
    if builder is None:
        raise ValueError("The model_index does not exist.")
    return builder()
예제 #4
0
 def build_attention(self):
     """Build ``self.model`` as a full classifier carrying ``self.attention_module``.

     The network named by ``self.base_model`` is constructed with its top
     layers, ``weights=None`` and ``self.num_classes`` classes. Weights are
     then loaded by layer name from the matching local ``*_notop.h5`` file.
     When ``self.base_model`` is not one of the three handled names, nothing
     is built and ``self.model`` is left untouched.
     """
     input_tensor = Input(shape=self.input_shape)

     # The branches only pick the constructor and its weights file; the
     # construction and loading steps are the same for every backbone.
     backbone = self.base_model
     if backbone == 'inception_v3':
         constructor = inception_v3.InceptionV3
         weights_file = 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'
     elif backbone == 'resnet_50':
         constructor = resnet_50.ResNet50
         weights_file = 'resnet_50_weights_tf_dim_ordering_tf_kernels_notop.h5'
     elif backbone == 'inception_resnet_v2':
         constructor = inception_resnet_v2.InceptionResNetV2
         weights_file = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5'
     else:
         # Unknown backbone: keep the silent no-op.
         return

     self.model = constructor(
         input_shape=self.input_shape,
         classes=self.num_classes,
         include_top=True,
         weights=None,
         input_tensor=input_tensor,
         attention_module=self.attention_module,
     )
     # by_name=True loads by layer name, so the "notop" file can be used on
     # a model that also has top layers.
     self.model.load_weights(weights_file, by_name=True)
예제 #5
0
 def bulid_bcnn_0(self):
     """Build ``self.model`` as a pretrained backbone plus a softmax classifier.

     The backbone selected by ``self.base_model`` is created without its top
     layers and with ``weights='imagenet'``, and stored as ``basenet1`` and
     ``model_detector``. Its output is reduced (global average pooling for the
     two Inception variants, ``Flatten`` for ResNet50) and fed into a
     ``Dense`` softmax layer with ``self.num_classes`` units.

     The misspelled method name is kept so existing callers keep working.

     Raises:
         ValueError: If ``self.base_model`` is not ``'inception_v3'``,
             ``'resnet_50'`` or ``'inception_resnet_v2'``. Previously such a
             value fell through every branch and ``Model`` was built from an
             unset (or stale) ``self.output``.
     """
     # Resolve the backbone before creating any tensors so that an
     # unsupported name fails fast with a clear message.
     if self.base_model == 'inception_v3':
         backbone_cls = inception_v3.InceptionV3
         extra_kwargs = {'classes': self.num_classes}
         use_flatten = False
     elif self.base_model == 'resnet_50':
         backbone_cls = resnet_50.ResNet50
         # The ResNet50 constructor was originally called without `classes`.
         extra_kwargs = {}
         use_flatten = True
     elif self.base_model == 'inception_resnet_v2':
         backbone_cls = inception_resnet_v2.InceptionResNetV2
         extra_kwargs = {'classes': self.num_classes}
         use_flatten = False
     else:
         raise ValueError("Unsupported base_model: {!r}".format(self.base_model))

     input_tensor = Input(shape=self.input_shape)
     self.basenet1 = backbone_cls(input_shape=self.input_shape, input_tensor=input_tensor,
                                  include_top=False, weights='imagenet', **extra_kwargs)
     self.model_detector = self.basenet1

     if use_flatten:
         # The ResNet50 head is left unnamed, exactly as before.
         x = Flatten()(self.basenet1.output)
         self.output = Dense(self.num_classes, activation='softmax')(x)
     else:
         x = GlobalAveragePooling2D(name='avg_pool')(self.basenet1.output)
         self.output = Dense(self.num_classes, activation='softmax', name='predictions')(x)
     self.model = Model(input_tensor, self.output)
예제 #6
0
    # Command-line options: which network to build, the float type to use,
    # and the MobileNet alpha value.
    parse = argparse.ArgumentParser(description='command for testing keras model with fp16 and fp32')
    # The help text lists exactly the names accepted by the dispatch below.
    # It previously advertised vgg16/resnet50/densenet121/inceptionv3, which
    # are rejected with ValueError. Adjacent string literals replace the old
    # backslash continuation, which embedded a run of spaces in the help text.
    parse.add_argument('--model', type=str, default='mobilenet',
                       help='support vgg, resnet, densenet, inception, '
                            'inception_resnet, xception, mobilenet, squeezenet')
    parse.add_argument('--dtype', type=str, default='float32')
    parse.add_argument('--alpha', type=float, default=0.5, help='alpha for mobilenet')
    args = parse.parse_args()

    # Set the float type before any layer is created.
    # NOTE(review): K is presumably the Keras backend module -- confirm.
    K.set_floatx(args.dtype)

    # create model
    if args.model == 'vgg':
        model = vgg16.VGG16(input_shape=(224, 224, 3))
    elif args.model == 'inception':
        model = inception_v3.InceptionV3(input_shape=(299, 299, 3))
    elif args.model == 'inception_resnet':
        model = inception_resnet_v2.InceptionResNetV2(input_shape=(299, 299, 3))
    elif args.model == 'xception':
        model = xception.Xception(input_shape=(299, 299, 3))
    elif args.model == 'resnet':
        model = resnet_50.ResNet50(input_shape=(224, 224, 3))
    elif args.model == 'densenet':
        model = densenet_121.DenseNet(reduction=0.5, classes=1000)
    elif args.model == 'squeezenet':
        model = squeezenet.SqueezeNet(input_shape=(227, 227, 3), classes=1000)
    elif args.model == 'mobilenet':
        model = mobilenet.MobileNet(input_shape=(224, 224, 3), alpha=args.alpha)
    else:
        raise ValueError("Do not support {}".format(args.model))
    model.summary()
    # The lookup key for MobileNet includes the alpha value, e.g.
    # 'mobilenet_0.5'; every other model uses its plain name.
    model_name = args.model if args.model != 'mobilenet' else args.model + '_' + str(args.alpha)
    # `weights` is a mapping defined outside this fragment; it supplies the
    # file name under ./weights/ for each key.
    model.load_weights('./weights/{}'.format(weights[model_name]), by_name=True)
예제 #7
0
def get_model():
    """Instantiate the network selected by the module-level ``model_index``.

    Returns:
        The object produced by the constructor/factory registered for
        ``model_index``.

    Raises:
        ValueError: If ``model_index`` matches none of the registered indices.
    """
    # Ordered (index, factory) table. Lambdas defer attribute lookup and
    # construction until an entry is actually selected.
    registry = (
        (0, lambda: mobilenet_v1.MobileNetV1()),
        (1, lambda: mobilenet_v2.MobileNetV2()),
        (2, lambda: mobilenet_v3_large.MobileNetV3Large()),
        (3, lambda: mobilenet_v3_small.MobileNetV3Small()),
        (4, lambda: efficientnet.efficient_net_b0()),
        (5, lambda: efficientnet.efficient_net_b1()),
        (6, lambda: efficientnet.efficient_net_b2()),
        (7, lambda: efficientnet.efficient_net_b3()),
        (8, lambda: efficientnet.efficient_net_b4()),
        (9, lambda: efficientnet.efficient_net_b5()),
        (10, lambda: efficientnet.efficient_net_b6()),
        (11, lambda: efficientnet.efficient_net_b7()),
        (12, lambda: resnext.ResNeXt50()),
        (13, lambda: resnext.ResNeXt101()),
        (14, lambda: inception_v4.InceptionV4()),
        (15, lambda: inception_resnet_v1.InceptionResNetV1()),
        (16, lambda: inception_resnet_v2.InceptionResNetV2()),
        (17, lambda: se_resnet.se_resnet_50()),
        (18, lambda: se_resnet.se_resnet_101()),
        (19, lambda: se_resnet.se_resnet_152()),
        (20, lambda: squeezenet.SqueezeNet()),
        (21, lambda: densenet.densenet_121()),
        (22, lambda: densenet.densenet_169()),
        (23, lambda: densenet.densenet_201()),
        (24, lambda: densenet.densenet_264()),
        (25, lambda: shufflenet_v2.shufflenet_0_5x()),
        (26, lambda: shufflenet_v2.shufflenet_1_0x()),
        (27, lambda: shufflenet_v2.shufflenet_1_5x()),
        (28, lambda: shufflenet_v2.shufflenet_2_0x()),
        (29, lambda: resnet.resnet_18()),
        (30, lambda: resnet.resnet_34()),
        (31, lambda: resnet.resnet_50()),
        (32, lambda: resnet.resnet_101()),
        (33, lambda: resnet.resnet_152()),
        (34, lambda: vgg16.VGG16()),
        (35, lambda: vgg16_mini.VGG16()),
        (36, lambda: VGG16_self.VGG16()),
        (10086, lambda: diy_resnet.resnet_50()),
    )
    # Scan in order with ``==`` so matching behaves like the comparison
    # ladder this table replaces.
    for index, factory in registry:
        if model_index == index:
            return factory()
    raise ValueError("The model_index does not exist.")