def image_retrieval(): # Run mode: (autoencoder -> simpleAE, convAE) or (transfer learning -> vgg19) modelName = "IncepResNet" # try: "simpleAE", "convAE", "vgg19" , "IncepResNet" trainModel = True parallel = False # use multicore processing # Make paths dataTrainDir = os.path.join(os.getcwd(), "data", "train") dataTestDir = os.path.join(os.getcwd(), "data", "test") outDir = os.path.join(os.getcwd(), "output", modelName) if not os.path.exists(outDir): os.makedirs(outDir) # Read images extensions = [".jpg", ".jpeg"] print("Reading train images from '{}'...".format(dataTrainDir)) imgs_train = read_imgs_dir(dataTrainDir, extensions, parallel=parallel) print("Reading test images from '{}'...".format(dataTestDir)) imgs_test = read_imgs_dir(dataTestDir, extensions, parallel=parallel) shape_img = imgs_train[0].shape print("Image shape = {}".format(shape_img)) strategy = tf.distribute.MirroredStrategy() # Build models if modelName in ["simpleAE", "convAE", "stackedAE"]: # Set up autoencoder info = { "shape_img": shape_img, "autoencoderFile": os.path.join(outDir, "{}_autoecoder.h5".format(modelName)), "encoderFile": os.path.join(outDir, "{}_encoder.h5".format(modelName)), "decoderFile": os.path.join(outDir, "{}_decoder.h5".format(modelName)), "checkpoint": os.path.join(outDir, "{}_checkpoint.h5".format(modelName)) } model = AutoEncoder(modelName, info) model.set_arch() if modelName == "simpleAE": shape_img_resize = shape_img input_shape_model = (model.encoder.input.shape[1], ) output_shape_model = (model.encoder.output.shape[1], ) n_epochs = 30 elif modelName in ["convAE", "stackedAE"]: shape_img_resize = shape_img input_shape_model = tuple( [int(x) for x in model.encoder.input.shape[1:]]) output_shape_model = tuple( [int(x) for x in model.encoder.output.shape[1:]]) n_epochs = 100 else: raise Exception("Invalid modelName!") elif modelName in ["vgg19", "ResNet50v2", "IncepResNet"]: pretrainedModel = Pretrained_Model(modelName, shape_img) model = pretrainedModel.buildModel() shape_img_resize, input_shape_model, output_shape_model = pretrainedModel.makeInOut( ) else: raise Exception("Invalid modelName!") # Print some model info print("input_shape_model = {}".format(input_shape_model)) print("output_shape_model = {}".format(output_shape_model)) # Apply transformations to all images class ImageTransformer(object): def __init__(self, shape_resize): self.shape_resize = shape_resize def __call__(self, img): img_transformed = resize_img(img, self.shape_resize) img_transformed = normalize_img(img_transformed) return img_transformed transformer = ImageTransformer(shape_img_resize) print("Applying image transformer to training images...") imgs_train_transformed = apply_transformer(imgs_train, transformer, parallel=parallel) print("Applying image transformer to test images...") imgs_test_transformed = apply_transformer(imgs_test, transformer, parallel=parallel) # Convert images to numpy array X_train = np.array(imgs_train_transformed).reshape((-1, ) + input_shape_model) X_test = np.array(imgs_test_transformed).reshape((-1, ) + input_shape_model) print(" -> X_train.shape = {}".format(X_train.shape)) print(" -> X_test.shape = {}".format(X_test.shape)) # Train (if necessary) if modelName in ["simpleAE", "convAE", "stackedAE"]: if trainModel: print('Number of devices: {}'.format( strategy.num_replicas_in_sync)) with strategy.scope(): model.compile(loss="binary_crossentropy", optimizer="adam") early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=1, patience=6, min_delta=0.0001) checkpoint = ModelCheckpoint(os.path.join( outDir, "{}_checkpoint.h5".format(modelName)), monitor="val_loss", mode="min", save_best_only=True) model.fit(X_train, n_epochs=n_epochs, batch_size=32, callbacks=[early_stopping, checkpoint]) model.save_models() else: model.load_models(loss="binary_crossentropy", optimizer="adam") # Create embeddings using model print("Inferencing embeddings using pre-trained model...") E_train = model.predict(X_train) E_train_flatten = E_train.reshape((-1, np.prod(output_shape_model))) E_test = model.predict(X_test) E_test_flatten = E_test.reshape((-1, np.prod(output_shape_model))) print(" -> E_train.shape = {}".format(E_train.shape)) print(" -> E_test.shape = {}".format(E_test.shape)) print(" -> E_train_flatten.shape = {}".format(E_train_flatten.shape)) print(" -> E_test_flatten.shape = {}".format(E_test_flatten.shape)) # Make reconstruction visualizations if modelName in ["simpleAE", "convAE", "stackedAE"]: print("Visualizing database image reconstructions...") imgs_train_reconstruct = model.decoder.predict(E_train) if modelName == "simpleAE": imgs_train_reconstruct = imgs_train_reconstruct.reshape( (-1, ) + shape_img_resize) plot_reconstructions(imgs_train, imgs_train_reconstruct, os.path.join( outDir, "{}_reconstruct.png".format(modelName)), range_imgs=[0, 255], range_imgs_reconstruct=[0, 1]) # Fit kNN model on training images print("Fitting k-nearest-neighbour model on training images...") knn = NearestNeighbors(n_neighbors=5, metric="cosine") knn.fit(E_train_flatten) # Perform image retrieval on test images print("Performing image retrieval on test images...") for i, emb_flatten in enumerate(E_test_flatten): # find k nearest train neighbours _, indices = knn.kneighbors([emb_flatten]) img_query = imgs_test[i] # query image imgs_retrieval = [imgs_train[idx] for idx in indices.flatten()] # retrieval images outFile = os.path.join(outDir, "{}_retrieval_{}.png".format(modelName, i)) plot_query_retrieval(img_query, imgs_retrieval, outFile)
E_test_flatten = E_test.reshape((-1, np.prod(output_shape_model))) print(" -> E_train.shape = {}".format(E_train.shape)) print(" -> E_test.shape = {}".format(E_test.shape)) print(" -> E_train_flatten.shape = {}".format(E_train_flatten.shape)) print(" -> E_test_flatten.shape = {}".format(E_test_flatten.shape)) # Make reconstruction visualizations if modelName in ["simpleAE", "convAE"]: print("Visualizing database image reconstructions...") imgs_train_reconstruct = model.decoder.predict(E_train) if modelName == "simpleAE": imgs_train_reconstruct = imgs_train_reconstruct.reshape( (-1, ) + shape_img_resize) plot_reconstructions(imgs_train, imgs_train_reconstruct, os.path.join(outPath, "{}_reconstruct.png".format(modelName)), range_imgs=[0, 255], range_imgs_reconstruct=[0, 1]) # Fit kNN model on training images print("Fitting k-nearest-neighbour model on training images...") knn = NearestNeighbors(n_neighbors=5, metric="cosine") knn.fit(E_train_flatten) # Perform image retrieval on test images print("Performing image retrieval on test images...") for i, emb_flatten in enumerate(E_test_flatten): _, indices = knn.kneighbors([emb_flatten ]) # find k nearest train neighbours img_query = imgs_test[i] # query image imgs_retrieval = [imgs_train[idx]