def iterative_prediction(image_path, image_name, result_image_path, result_path, iterative_time):
    """Run EAST detection repeatedly, calling ``erase_image`` between passes.

    The source image is copied into ``result_image_path`` as
    ``<stem>+0<ext>``. Pass ``i`` (1-based) runs ``predict`` on
    ``<stem>+<i-1><ext>``. Each pass lowers the text pixel and text side
    thresholds by ``0.05 * i`` and raises the text trunc threshold by
    ``0.03 * i``; the action/arrow/nock thresholds come straight from
    ``cfg``. Each pass's results are accumulated, and ``erase_image`` is
    called with them on the image that pass read. All accumulated results
    are written to ``<result_path>/<stem>.txt``.

    Args:
        image_path: Directory containing the source image.
        image_name: File name (with extension) of the source image inside
            ``image_path``.
        result_image_path: Directory for per-iteration images; created,
            including missing parents, if it does not exist.
        result_path: Directory for the ``<stem>.txt`` results file; created,
            including missing parents, if it does not exist.
        iterative_time: Number of prediction passes to run.

    NOTE(review): this function never writes ``<stem>+<i><ext>`` itself, yet
    pass ``i + 1`` reads it; presumably ``predict`` or ``erase_image``
    produces that file — confirm.
    NOTE(review): for large ``iterative_time`` the adjusted thresholds may
    leave their valid range; confirm the intended bounds.
    """
    # exist_ok avoids the exists()/mkdir() race and also creates missing parents.
    os.makedirs(result_image_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)

    stem, ext = os.path.splitext(image_name)
    # The input of pass 1 is a verbatim copy of the source image.
    shutil.copyfile(os.path.join(image_path, image_name),
                    _iteration_image_path(result_image_path, stem, ext, 0))

    # Build the detector once and reuse it for every pass.
    east = East()
    east_detect = east.east_network()
    east_detect.load_weights(cfg.saved_model_weights_file_path)

    results = []
    for i in range(1, iterative_time + 1):
        previous_image = _iteration_image_path(result_image_path, stem, ext, i - 1)
        # Text thresholds are relaxed a little more each pass, presumably to
        # pick up weaker text once stronger detections have been erased.
        current_result = predict(
            east_detect,
            previous_image,
            text_pixel_threshold=cfg.text_pixel_threshold - i * 0.05,
            text_side_threshold=cfg.text_side_vertex_pixel_threshold - i * 0.05,
            text_trunc_threshold=cfg.text_trunc_threshold + i * 0.03,
            action_pixel_threshold=cfg.action_pixel_threshold,
            arrow_trunc_threshold=cfg.arrow_trunc_threshold,
            nock_trunc_threshold=cfg.nock_trunc_threshold,
            quiet=False,
        )
        results.extend(current_result)
        # Erase this pass's detections from the image it read, presumably so
        # the next pass does not report them again.
        erase_image(current_result, previous_image)

    # writelines adds no separators; assumes predict returns newline-terminated
    # strings — confirm.
    with open(os.path.join(result_path, stem + '.txt'), 'w') as result_fp:
        result_fp.writelines(results)


def _iteration_image_path(result_image_path, stem, ext, index):
    """Return ``<result_image_path>/<stem>+<index><ext>``, the input of pass ``index + 1``."""
    return os.path.join(result_image_path, stem + '+' + str(index) + ext)
def main(img_path):
    """Run self-iterative EAST detection on ``img_path``, erasing hits between passes.

    For each of ``int(cfg.max_self_iteration)`` passes this calls
    ``predict(east_detect, img_path, i, quiet=True)``. It then reads the
    detections back from ``<base>.txt``, where ``base`` is ``img_path``
    without its last four characters. With those detections it:

    * paints a filled white rectangle, padded by ``cfg.erase_offset_pixels``,
      over every detection on the pass input image and saves the result as
      ``<base>_<i>.png``. The pass input is ``img_path`` for pass 0 and
      ``<base>_<i-1>.png`` afterwards.
    * draws a 1-px (255, 0, 0) outline of every detection on a freshly loaded
      copy of ``img_path`` and saves it as ``<base>_predict_<i>.png``.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the input images.
    """
    east = East()
    east_detect = east.east_network()
    east_detect.load_weights(cfg.saved_model_weights_file_path)

    # NOTE(review): slicing off 4 characters assumes a 3-letter extension such
    # as ".jpg"; kept because predict presumably names its .txt output the
    # same way — confirm before switching to os.path.splitext.
    base = img_path[:-4]
    pad = cfg.erase_offset_pixels

    for i in range(int(cfg.max_self_iteration)):
        # NOTE(review): i is passed positionally where other call sites pass a
        # pixel threshold; presumably this predict variant takes the
        # iteration index — confirm.
        predict(east_detect, img_path, i, quiet=True)

        # load predicted results
        with open(base + '.txt') as result_fp:
            results = result_fp.readlines()

        # Pass 0 erases from the original; later passes continue from the
        # previous pass's erased image.
        image_path = img_path if i == 0 else base + '_' + str(i - 1) + '.png'
        input_img = _read_image(image_path)
        ori_img = _read_image(img_path)

        for line in results:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # with IndexError on geo[1].
            if not line.strip():
                continue
            geo = line.split(',')
            # Fields 0,1 and 4,5 are used as opposite corners of the box.
            x1, y1 = int(geo[0]), int(geo[1])
            x2, y2 = int(geo[4]), int(geo[5])
            # erase detected text (negative thickness fills the rectangle)
            cv2.rectangle(input_img, (x1 - pad, y1 - pad), (x2 + pad, y2 + pad),
                          (255, 255, 255), thickness=-1)
            # mark predicted results
            cv2.rectangle(ori_img, (x1, y1), (x2, y2), (255, 0, 0), thickness=1)

        cv2.imwrite(base + '_' + str(i) + '.png', input_img)
        cv2.imwrite(base + '_predict_' + str(i) + '.png', ori_img)


def _read_image(path):
    """Load ``path`` with ``cv2.imread``; raise instead of returning ``None`` on failure."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('could not read image: ' + path)
    return img
def train_model():
    """Build, compile and train the EAST network using settings from ``cfg``.

    The model is compiled with ``quad_loss`` and a Nadam optimizer
    (``cfg.lr``, ``cfg.decay``). If ``cfg.load_weights`` is truthy and the
    weights file exists, weights are loaded by layer name, skipping
    mismatched layers. Training runs on ``gen()``, with
    ``gen(is_val=True)`` as validation data, under early stopping. The best
    weights are checkpointed to ``cfg.saved_model_weights_file_path``.
    """
    builder = East()
    model = builder.east_network()
    model.summary()

    nadam = Nadam(lr=cfg.lr,
                  # clipvalue=cfg.clipvalue,
                  schedule_decay=cfg.decay)
    model.compile(loss=quad_loss, optimizer=nadam)

    # Resume from earlier weights when enabled and available.
    should_resume = cfg.load_weights and os.path.exists(cfg.saved_model_weights_file_path)
    if should_resume:
        model.load_weights(cfg.saved_model_weights_file_path,
                           by_name=True,
                           skip_mismatch=True)

    print('start training task:' + cfg.train_task_id + '....................')

    # Train on the current scale's data.
    train_data = gen()
    val_data = gen(is_val=True)
    stop_early = EarlyStopping(patience=cfg.patience, verbose=2)
    keep_best = ModelCheckpoint(filepath=cfg.saved_model_weights_file_path,
                                save_best_only=True,
                                save_weights_only=True,
                                verbose=1)
    model.fit_generator(generator=train_data,
                        steps_per_epoch=cfg.steps_per_epoch,
                        epochs=cfg.epoch_num,
                        validation_data=val_data,
                        validation_steps=cfg.validation_steps,
                        verbose=2,
                        initial_epoch=cfg.initial_epoch,
                        callbacks=[stop_early, keep_best])
    del model, builder
def parse_args(argv=None):
    """Parse the command-line options for single-image prediction.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv``, as before.

    Returns:
        argparse.Namespace with ``path`` (image path) and ``threshold``
        (a float when given on the command line, otherwise the
        ``cfg.pixel_threshold`` default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-p', default='../data/003.jpg',
                        help='image path')
    # type=float turns a malformed value into an argparse usage error
    # instead of a ValueError traceback later on.
    parser.add_argument('--threshold', '-t', type=float,
                        default=cfg.pixel_threshold,
                        help='pixel activation threshold')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    img_path = args.path
    # float() still normalises a non-float cfg default.
    threshold = float(args.threshold)
    print(img_path, threshold)
    east = East()
    east_detect = east.east_network()
    east_detect.load_weights(cfg.saved_model_weights_file_path)
    predict(east_detect, img_path, threshold)
import os

from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam

import cfg
from network import East
from losses import quad_loss
from data_generator import gen

# Training entry point. Guarded so that importing this module does not build
# the network and start a full training run as a side effect.
if __name__ == '__main__':
    east = East()
    east_network = east.east_network()
    east_network.summary()
    east_network.compile(loss=quad_loss,
                         optimizer=Adam(lr=cfg.lr,
                                        # clipvalue=cfg.clipvalue,
                                        decay=cfg.decay))
    # Resume from previously saved weights when enabled and present.
    if cfg.load_weights and os.path.exists(cfg.saved_model_weights_file_path):
        east_network.load_weights(cfg.saved_model_weights_file_path)
    # NOTE(review): checkpoints are written to cfg.model_weights_path, but
    # resuming reads cfg.saved_model_weights_file_path, and nothing here
    # writes the latter — confirm another step saves weights there.
    east_network.fit_generator(generator=gen(),
                               steps_per_epoch=cfg.steps_per_epoch,
                               epochs=cfg.epoch_num,
                               validation_data=gen(is_val=True),
                               validation_steps=cfg.validation_steps,
                               verbose=1,
                               initial_epoch=cfg.initial_epoch,
                               callbacks=[
                                   EarlyStopping(patience=cfg.patience, verbose=1),
                                   ModelCheckpoint(filepath=cfg.model_weights_path,
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   verbose=1)])
def start_prediction(cropped_path, *, weights_path="./east_model_weights_3T736.h5",
                     pixel_threshold=None):
    """Build the EAST detector, load weights and run ``predict`` on one image.

    Args:
        cropped_path: Image path passed to ``predict``.
        weights_path: Weights file to load. Defaults to the previously
            hard-coded ``./east_model_weights_3T736.h5``, which is resolved
            relative to the current working directory.
        pixel_threshold: Threshold passed to ``predict``. ``None`` (the
            default) uses ``cfg.pixel_threshold``, read at call time.
    """
    if pixel_threshold is None:
        pixel_threshold = cfg.pixel_threshold
    east = East()
    east_detect = east.east_network()
    east_detect.load_weights(weights_path)
    predict(east_detect, cropped_path, pixel_threshold)