def train_factory(MODEL_NAME):
    """Build the classifier selected by ``MODEL_NAME`` and pass it to ``train``.

    Before the model is constructed, a TensorFlow session using the BFC
    allocator with on-demand GPU memory growth is installed via
    ``set_session``. Every architecture receives an RGB input of shape
    ``(img_width, img_height, 3)`` and ``charset_size`` output classes, both
    read from module-level names.

    Args:
        MODEL_NAME: key naming the architecture, e.g. ``'resnet50'``,
            ``'xception'``, ``'mobilenet_v2'``. See ``builders`` below for the
            full list.

    Raises:
        ValueError: if ``MODEL_NAME`` is not a supported architecture. The
            check happens before any TensorFlow session is created.
    """
    input_shape = (img_width, img_height, 3)

    # Lambdas defer construction so that only the selected model is built.
    builders = {
        'inception_resnet_v2': lambda: InceptionResNetV2.inception_resnet_v2(
            input_shape=input_shape, classes=charset_size,
            weights='./trained_model/inception_resnet_v2/inception_resnet_v2.12-0.8244.hdf5'),
        'xception': lambda: Xception.Xception(input_shape, classes=charset_size),
        'mobilenet_v2': lambda: MobileNetv2.MobileNet_v2(input_shape, classes=charset_size),
        'inception_v3': lambda: Inception_v3.inception(input_shape, classes=charset_size),
        'vgg16': lambda: VGGNet.vgg(input_shape=input_shape, classes=charset_size),
        'vgg19': lambda: VGG19.VGG19(
            input_shape=input_shape, classes=charset_size,
            weights='weights/vgg19_weights_tf_dim_ordering_tf_kernels.h5'),
        'resnet50': lambda: ResNet50.resnet(input_shape=input_shape, classes=charset_size),
        'inception_v4': lambda: inception_v4.inception_v4(input_shape=input_shape, classes=charset_size),
        'resnet34': lambda: ResNet34.ResNet34(input_shape=input_shape, classes=charset_size),
        'densenet121': lambda: DenseNet.DenseNet(input_shape=input_shape, classes=charset_size),
        # NOTE(review): identical to 'densenet121'. Confirm whether DenseNet
        # should be configured differently for the 161-layer variant.
        'densenet161': lambda: DenseNet.DenseNet(input_shape=input_shape, classes=charset_size),
        'shufflenet_v2': lambda: ShuffleNetV2.ShuffleNetV2(input_shape=input_shape, classes=charset_size),
        'resnet_attention_56': lambda: Resnet_Attention_56.Resnet_Attention_56(
            input_shape=input_shape, classes=charset_size),
        'squeezenet': lambda: SqueezeNet.SqueezeNet(input_shape=input_shape, classes=charset_size),
        'seresnet50': lambda: SEResNet50.SEResNet50(input_shape=input_shape, classes=charset_size),
        'se_resnext': lambda: SEResNext.SEResNext(input_shape=input_shape, classes=charset_size),
        'nasnet': lambda: NASNet.NASNetLarge(input_shape=input_shape, classes=charset_size),
        'custom': lambda: Custom_Network.Custom_Network(input_shape=input_shape, classes=charset_size),
        # This builder names its class-count argument num_outputs, not classes.
        'resnet18': lambda: ResnetBuilder.build_resnet_18(input_shape=input_shape, num_outputs=charset_size),
    }

    # Fail fast on a typo instead of crashing later on ``None.summary()``.
    if MODEL_NAME not in builders:
        raise ValueError('Unknown MODEL_NAME %r; expected one of: %s'
                         % (MODEL_NAME, ', '.join(sorted(builders))))

    # Grow GPU memory on demand rather than reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allocator_type = 'BFC'
    config.gpu_options.allow_growth = True
    set_session(tf.Session(config=config))

    model = builders[MODEL_NAME]()
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None".
    model.summary()
    train(model, MODEL_NAME)
num_epochs = configs['num_epochs']
data_path = 'data/HARRISON/preprocessed_data/'
image_path = '/home/eric/data/HARRISON/'
model_name = 'resnet50'
object_image_features_filename = 'resnet50_image_name_to_features.h5'
generator = Generator(data_path=data_path, batch_size=batch_size,
                      image_features_filename=object_image_features_filename)

# Use the BFC allocator and let TensorFlow grow GPU memory on demand.
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

# One output per generator vocabulary entry, trained with binary
# cross-entropy (presumably independent multi-label hashtag scores; confirm
# against the model definition).
model = ResNet50.resnet(
    input_shape=(configs['img_width'], configs['img_height'], configs['channel']),
    classes=generator.VOCABULARY_SIZE)
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=["accuracy", fmeasure, precision, recall])
# summary() prints itself and returns None, so it is not wrapped in print().
model.summary()
plot_model(model, show_shapes=True, to_file=model_name + '.png')

save_path = os.path.join('trained_model', model_name)
if not os.path.exists(save_path):
    os.makedirs(save_path)
model_names = os.path.join(save_path, 'hashtag_weights.{epoch:02d}-{val_loss:.4f}.hdf5')
# val_loss must be minimised. With save_best_only=True, mode='max' would keep
# only the checkpoints whose validation loss got *worse*.
model_checkpoint = ModelCheckpoint(filepath=model_names, monitor='val_loss', verbose=1,
                                   save_best_only=True, mode='min')
callbacks = [model_checkpoint]
num_training_samples = generator.training_dataset.shape[0]
num_validation_samples = generator.validation_dataset.shape[0]
print('Number of training samples:', num_training_samples)