Пример #1
0
def load_models():
    """Construct every supported model and return them together.

    Each model class is called with the path of its ``.h5`` file under
    ``model/``.

    Returns:
        list: The ``VGGNet``, ``ResNet`` and ``VGGNet5`` instances, in that
        order.
    """
    return [
        VGGNet("model/vggnet.h5"),
        ResNet("model/resnet.h5"),
        VGGNet5("model/vggnet5.h5"),
    ]
Пример #2
0
     if args.model_type == 'ex_atten':
         model = VGG_interpretable_atten(num_classes=num_classe)
     elif args.model_type == 'ex':
         model = VGG_interpretable(num_classes=num_classe)
     elif args.model_type == 'atten':
         model = VGG_atten(num_classes=num_classe)
     elif args.model_type == 'gradcam':
         model = VGG_gradcam(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam':
         model = VGG_interpretable_gradcam(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam5':
         model = VGG_interpretable_gradcam(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam2':
         model = VGG_interpretable_gradcam2(num_classes=num_classe)
     else:
         model = VGGNet(num_classes=num_classe)
         #model = VGG_gradcam(num_classes=num_classe)
 elif args.model == 'resnet':
     if args.model_type == 'ex_atten':
         model = VGG_interpretable_atten(num_classes=num_classe)
     elif args.model_type == 'ex':
         model = Resnet_interpretable(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam':
         model = Resnet_interpretable_gradcam(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam2':
         model = VGG_interpretable_gradcam2(num_classes=num_classe)
     else:
         model = Resnet(num_classes=num_classe)
 elif args.model == 'mobilenet':
     if args.model_type == 'ex_atten':
         model = VGG_interpretable_atten(num_classes=num_classe)
Пример #3
0
def train_factory(MODEL_NAME):
    """Build the network registered under ``MODEL_NAME`` and train it.

    The input shape is ``(img_width, img_height, 3)`` and the number of
    output classes is ``charset_size``; both are read from module globals.
    After the model is built its summary is printed and it is handed to
    ``train(model, MODEL_NAME)``.

    Args:
        MODEL_NAME: Key of the architecture to build, e.g. ``'vgg16'``,
            ``'resnet50'`` or ``'mobilenet_v2'``.

    Raises:
        ValueError: If ``MODEL_NAME`` is not one of the registered names.
    """
    # Builders are lambdas so that only the selected architecture's module is
    # touched, exactly as with a chain of if/elif branches.
    builders = {
        'inception_resnet_v2': lambda shape: InceptionResNetV2.inception_resnet_v2(
            input_shape=shape, classes=charset_size,
            weights='./trained_model/inception_resnet_v2/inception_resnet_v2.12-0.8244.hdf5'),
        'xception': lambda shape: Xception.Xception(
            shape, classes=charset_size),
        'mobilenet_v2': lambda shape: MobileNetv2.MobileNet_v2(
            shape, classes=charset_size),
        'inception_v3': lambda shape: Inception_v3.inception(
            shape, classes=charset_size),
        'vgg16': lambda shape: VGGNet.vgg(
            input_shape=shape, classes=charset_size),
        'vgg19': lambda shape: VGG19.VGG19(
            input_shape=shape, classes=charset_size,
            weights='weights/vgg19_weights_tf_dim_ordering_tf_kernels.h5'),
        'resnet50': lambda shape: ResNet50.resnet(
            input_shape=shape, classes=charset_size),
        'inception_v4': lambda shape: inception_v4.inception_v4(
            input_shape=shape, classes=charset_size),
        'resnet34': lambda shape: ResNet34.ResNet34(
            input_shape=shape, classes=charset_size),
        'densenet121': lambda shape: DenseNet.DenseNet(
            input_shape=shape, classes=charset_size),
        # NOTE(review): built with the same call as 'densenet121' -- confirm
        # whether a DenseNet-161 configuration was intended here.
        'densenet161': lambda shape: DenseNet.DenseNet(
            input_shape=shape, classes=charset_size),
        'shufflenet_v2': lambda shape: ShuffleNetV2.ShuffleNetV2(
            input_shape=shape, classes=charset_size),
        'resnet_attention_56': lambda shape: Resnet_Attention_56.Resnet_Attention_56(
            input_shape=shape, classes=charset_size),
        'squeezenet': lambda shape: SqueezeNet.SqueezeNet(
            input_shape=shape, classes=charset_size),
        'seresnet50': lambda shape: SEResNet50.SEResNet50(
            input_shape=shape, classes=charset_size),
        'se_resnext': lambda shape: SEResNext.SEResNext(
            input_shape=shape, classes=charset_size),
        'nasnet': lambda shape: NASNet.NASNetLarge(
            input_shape=shape, classes=charset_size),
        'custom': lambda shape: Custom_Network.Custom_Network(
            input_shape=shape, classes=charset_size),
        # This builder names its class-count argument ``num_outputs``.
        'resnet18': lambda shape: ResnetBuilder.build_resnet_18(
            input_shape=shape, num_outputs=charset_size),
    }

    # Fail fast, before any GPU session is created, instead of crashing later
    # with an AttributeError on ``None``.
    build = builders.get(MODEL_NAME)
    if build is None:
        raise ValueError(
            "Unknown MODEL_NAME %r; expected one of: %s"
            % (MODEL_NAME, ', '.join(sorted(builders))))

    # Use the BFC allocator and let GPU memory grow on demand.
    config = tf.ConfigProto()
    config.gpu_options.allocator_type = 'BFC'
    config.gpu_options.allow_growth = True
    set_session(tf.Session(config=config))

    model = build((img_width, img_height, 3))

    # Keras ``summary()`` prints the table itself and returns None, so it is
    # not wrapped in print().
    model.summary()
    train(model, MODEL_NAME)
Пример #4
0
# Batch loader over the training split: reshuffled on every pass, loaded by
# two worker processes.
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=2)

# Batch loader over the validation split, configured like the training one.
# NOTE(review): shuffle=True is normally unnecessary for validation -- confirm
# it is intentional.
valloader = torch.utils.data.DataLoader(valset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=2)

# Human-readable class names; presumably indexed by the dataset's integer
# label -- TODO confirm against the dataset definition.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

# Build model
# The argument 10 matches len(classes); presumably the number of output
# classes -- verify against VGGNet's constructor.
vgg16 = VGGNet(10).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg16.parameters(), lr=lr, momentum=momentum)

# Training
print("Start training")
# Starting value for the best validation accuracy.
best_val_acc = 0.0
for epoch in range(epochs):
    # train part
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
Пример #5
0
import numpy as np
import argparse
import glob

from vgg16 import VGGNet

# Command-line interface: --index is the path of the HDF5 file to write.
# NOTE(review): arguments are parsed at import time, outside the __main__
# guard -- confirm this module is never imported by other code.
ap = argparse.ArgumentParser()
ap.add_argument("--index", required=True, help="Name of index file")
args = vars(ap.parse_args())

if __name__ == "__main__":
    # Local import keeps this script block self-contained.
    import os

    feats = []
    names = []

    # NOTE(review): h5py is used below but is not among the imports visible
    # at the top of this script -- confirm it is imported.
    model = VGGNet()
    for imgPath in glob.glob("../../" + "database" + "/*.jpg"):
        # basename() strips the directory for either path separator, whereas
        # searching for "/" alone leaves "database\..." in the id on Windows.
        imageId = os.path.basename(imgPath)
        norm_feat = model.extract_feat(imgPath)
        feats.append(norm_feat)
        names.append(imageId)

    # Stack the per-image features into a single array.
    feats = np.array(feats)

    output = args["index"]

    # Write the features to the index file.
    h5f = h5py.File(output, 'w')
    h5f.create_dataset('dataset_1', data=feats)
Пример #6
0
# CIFAR-10 test split, downloaded into ./data if it is not already there.
testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform)

# NOTE(review): shuffle=True is normally unnecessary for evaluation -- confirm
# it is intentional.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=2)

# Human-readable class names; presumably indexed by the dataset's integer
# label -- TODO confirm.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

# Load the saved model
vgg16 = VGGNet(10).to(device)

optimizer = optim.SGD(vgg16.parameters(), lr=lr, momentum=momentum)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load('vgg16_ckpt.pt', map_location=device)
vgg16.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# PATH = 'vgg16.pth'
# vgg16.load_state_dict(torch.load(PATH))

# Accuracy over the entire test dataset
correct = 0
total = 0