def train_autokeras(RESIZE_TRAIN_IMG_DIR, RESIZE_TEST_IMG_DIR, TRAIN_CSV_DIR,
                    TEST_CSV_DIR, TIME, model_dir=None):
    """Search for, retrain, evaluate and export an AutoKeras image classifier.

    Args:
        RESIZE_TRAIN_IMG_DIR: directory holding the training images
            (presumably already resized, per the name — confirm).
        RESIZE_TEST_IMG_DIR: directory holding the test images.
        TRAIN_CSV_DIR: CSV file that lists the training images and labels.
        TEST_CSV_DIR: CSV file that lists the test images and labels.
        TIME: time budget for the architecture search, passed as
            ``time_limit`` to ``ImageClassifier.fit`` (seconds).
        model_dir: path the best model is exported to with
            ``export_keras_model``. When omitted, the module-level
            ``MODEL_DIR`` is used, which is what this function always did.

    Returns:
        A tuple ``(clf, test_score, train_score)``: the fitted classifier and
        the results of ``clf.evaluate`` on the test and training sets.
    """
    if model_dir is None:
        # Backward-compatible default: the original always exported to the
        # global MODEL_DIR, resolved at call time.
        model_dir = MODEL_DIR

    # Load images and labels listed in the CSV files.
    train_data, train_labels = load_image_dataset(
        csv_file_path=TRAIN_CSV_DIR, images_path=RESIZE_TRAIN_IMG_DIR)
    test_data, test_labels = load_image_dataset(
        csv_file_path=TEST_CSV_DIR, images_path=RESIZE_TEST_IMG_DIR)

    # Scale pixel values to [0, 1] (assumes 0-255 input — confirm).
    train_data = train_data.astype('float32') / 255
    test_data = test_data.astype('float32') / 255

    clf = ImageClassifier(verbose=True)
    # Architecture search within the given time budget.
    clf.fit(train_data, train_labels, time_limit=TIME)
    # Retrain the best model found by the search.
    # NOTE(review): the test set is handed to final_fit here and is then also
    # used for the reported test score; a separate validation split would
    # give an unbiased estimate.
    clf.final_fit(train_data, train_labels, test_data, test_labels,
                  retrain=True)

    y = clf.evaluate(test_data, test_labels)
    print("测试集精确度:", y)
    # Training-set score: shows how well the model fits, not generalization.
    score = clf.evaluate(train_data, train_labels)
    print("训练集精确度:", score)

    # Save the best model as a Keras model file.
    clf.export_keras_model(model_dir)
    return clf, y, score
from autokeras.image.image_supervised import ImageClassifier, load_image_dataset

# All training images were copied into ../data/all; labels come from the CSV.
train_path = '../data/all'
train_labels = '../data/labels_train.csv'

x_train, y_train = load_image_dataset(csv_file_path=train_labels,
                                      images_path=train_path)

clf = ImageClassifier(verbose=True)
# Architecture search with a 4-hour budget (time_limit is in seconds).
clf.fit(x_train, y_train, time_limit=4 * 60 * 60)

# export_keras_model() takes the output file name and writes the model itself;
# it does not return the graph. Fetch the best graph from the searcher and
# build the Keras model from it instead.
best_model = clf.load_searcher().load_best_model()
keras_model = best_model.produce_keras_model()
keras_model.summary()

# Save it.
keras_model.save('best.hdf5')

# NOTE: to retrain on a validation split, load it with load_image_dataset and
# call clf.final_fit(x_train, y_train, x_val, y_val, retrain=True,
# trainer_args={'max_iter_num': 10}), then clf.evaluate(x_val, y_val).
# Use the AutoKeras image classifier.
clf = ImageClassifier(verbose=True)
# Give it the training data and labels. The search time can be capped
# (time_limit is in seconds, so 1 * 60 is one minute); AutoKeras keeps
# searching for the best network architecture until the budget runs out.
clf.fit(train_data, train_labels, time_limit=1 * 60)
# Once the best model is found, run one last round of training and validation.
clf.final_fit(train_data, train_labels, test_data, test_labels, retrain=True)
# Report the evaluation result.
y = clf.evaluate(test_data, test_labels)
print("evaluate:", y)

# Try a prediction on a single image to see whether it is correct.
# Resize on load: without target_size, the reshape below either raises for
# images of another size or, when only the pixel count matches, silently
# scrambles the pixel layout.
img = load_img(PREDICT_IMG_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE))
x = img_to_array(img)
# Same /255 scaling as applied to the training data — confirm train_data was
# scaled this way.
x = x.astype('float32') / 255
x = np.reshape(x, (1, IMAGE_SIZE, IMAGE_SIZE, 3))
print("x shape:", x.shape)
# The result is a numpy array of predicted labels.
# NOTE(review): the original comment said a prediction of 4 means "horse";
# confirm against the dataset's label mapping.
y = clf.predict(x)
print("predict:", y)
# Export the generated model.
clf.export_keras_model(MODEL_DIR)
# Reload the model with Keras.
model = load_model(MODEL_DIR)
# Render the model architecture to an image.
plot_model(model, to_file=MODEL_PNG)
# NOTE(review): the next two statements look like the tail of a loop over
# training images whose header is above this chunk; they presumably belong
# indented inside that loop — confirm before running.
x_train.append(img)
y_train.append(0)
x_train = np.array(x_train)
y_train = np.array(y_train)

# Label convention: 0 = normal, 1 = anomaly. The original labelled the
# test/anomaly images 0 as well, so both classes were indistinguishable.
# NOTE(review): the training loops above this chunk must use the same labels.
for label, test_dir in ((0, "test/normal"), (1, "test/anomaly")):
    for file_name in sorted(os.listdir(test_dir)):
        img_path = os.path.join(test_dir, file_name)
        img = cv2.imread(img_path)
        # cv2.imread returns None instead of raising for unreadable files;
        # storing None would only fail much later inside the classifier.
        if img is None:
            raise ValueError("could not read image: " + img_path)
        x_test.append(img)
        y_test.append(label)
x_test = np.array(x_test)
y_test = np.array(y_test)

print(x_train.shape)
print(y_train.shape)

clf = ImageClassifier(verbose=True)
clf.fit(x_train, y_train, time_limit=12 * 60 * 60)  # 12-hour search budget
clf.final_fit(x_train, y_train, x_test, y_test, retrain=True)
clf.export_autokeras_model("./autokeras_model.bin")  # Save a model loadable by Auto-Keras
clf.export_keras_model("./keras_model.bin")  # Save a model loadable by Keras
acc = clf.evaluate(x_test, y_test)
y_test = []
base_path = "../data-deep-fashion-women/img/"
# Load the test labels from the local CSV into a dataframe. As used below,
# column 0 is the image path relative to base_path and column 1 its label.
df = pd.read_csv('../data-deep-fashion-women/img/WOMEN/labels_test.csv')
print(len(df))
# Positional column access: row[0] on an iterrows() Series relies on pandas'
# deprecated integer-key fallback.
for rel_path, label in zip(df.iloc[:, 0], df.iloc[:, 1]):
    img = image.load_img(base_path + rel_path, target_size=(224, 224))
    # img_to_array already returns a NumPy array; no extra copy needed.
    x_test.append(image.img_to_array(img))
    y_test.append(label)

import os
from autokeras.image.image_supervised import load_image_dataset
from autokeras.image.image_supervised import ImageClassifier

# Hand AutoKeras NumPy arrays rather than Python lists (np.asarray is a no-op
# for inputs that already are arrays).
x_train = np.asarray(x_train)
y_train = np.asarray(y_train)
x_test = np.asarray(x_test)
y_test = np.asarray(y_test)

# Create the output directory up front: the save calls below do not create it,
# and failing only after a 10-hour search would lose the run.
os.makedirs('./_models', exist_ok=True)

clf = ImageClassifier(verbose=True)
clf.fit(x_train, y_train, time_limit=10 * 60 * 60)  # 10 hours
clf.final_fit(x_train, y_train, x_test, y_test, retrain=True)
y = clf.evaluate(x_test, y_test)
print(y)
clf.export_autokeras_model('./_models/nas_1.h5')
clf.export_keras_model('./_models/nas_2.h5')
clf.load_searcher().load_best_model().produce_keras_model().save(
    './_models/nas_3.h5')