def ocr_for_single_line(self, img_fp):
    """
    Recognize the characters of an image that contains a single line of text.

    :param img_fp: path of an image file; or a gray image given as
        mx.nd.NDArray or np.ndarray
    :return: list of recognized characters, such as ['你', '好']
    :raises TypeError: if `img_fp` is neither the path of an existing file
        nor an mx.nd.NDArray / np.ndarray
    """
    # Work on a private copy of the stored hyper-parameters.
    params = deepcopy(self._hp)

    is_existing_path = isinstance(img_fp, str) and os.path.isfile(img_fp)
    if is_existing_path:
        image = read_ocr_img(img_fp)
    elif isinstance(img_fp, (mx.nd.NDArray, np.ndarray)):
        image = img_fp
    else:
        # Also reached for a string that does not name an existing file.
        raise TypeError('Inappropriate argument type.')

    image = rescale_img(image, params)

    # Batch of one image: the image itself plus the initial LSTM states.
    state_names, state_arrays = lstm_init_states(batch_size=1, hp=params)
    batch = SimpleBatch(
        data_names=['data'] + state_names,
        data=[mx.nd.array([image])] + state_arrays,
    )

    module = self._get_module(params, batch)
    module.forward(batch)
    probs = module.get_outputs()[0].asnumpy()

    # Take the arg-max over the last axis, then let the CTC helper turn the
    # raw index sequence into labels; its second return value is not needed.
    best_indices = np.argmax(probs, axis=-1).tolist()
    label_ids, _ = CtcMetrics.ctc_label(best_indices)

    charset = self._alphabet
    return [charset[label_id] for label_id in label_ids]
def ocr_for_single_lines(self, img_list):
    """
    Recognize a batch of images that each contain a single line of text.

    :param img_list: list of line images, each an mx.nd.NDArray or
        np.ndarray. Each element should be a tensor with values ranging
        from 0 to 255 and with shape [height, width] or
        [height, width, channel]; the optional channel should be 1 (gray
        image) or 3 (color image).
    :return: one list of chars per input image, such as
        [['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']];
        an empty list when `img_list` is empty
    """
    if len(img_list) == 0:
        return []

    arrays = [self._preprocess_img_array(img) for img in img_list]
    batch_size = len(arrays)
    padded, widths = self._pad_arrays(arrays)

    batch = SimpleBatch(data_names=['data'], data=[mx.nd.array(padded)])
    flat_prob = self._predict(batch)

    # Regroup the prediction into (-1, batch_size, num_classes); the leading
    # axis is presumably the sequence length -- TODO confirm against _predict.
    num_classes = flat_prob.shape[1]
    prob = np.reshape(flat_prob, (-1, batch_size, num_classes))

    # A candidate alphabet is configured: multiply the scores by its mask.
    if self._cand_alph_idx is not None:
        prob = prob * self._gen_mask(prob.shape)

    widest = max(widths)
    return [
        self._gen_line_pred_chars(prob[:, idx, :], widths[idx], widest)
        for idx in range(batch_size)
    ]
def main():
    """
    Command-line entry point: run a saved checkpoint on one image and print
    the prediction.

    Options:
      --dataset       'captcha' (default) or 'cn_ocr'; selects the
                      hyper-parameter class and the image reader.
      --file          path of the image to recognize (required).
      --prefix        checkpoint prefix, default './models/model'.
      --epoch         checkpoint epoch, default 20.
      --charset_file  optional charset file; when given, predicted ids are
                      mapped to characters, otherwise they are printed as
                      digits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        help="use which kind of dataset, captcha or cn_ocr",
        choices=['captcha', 'cn_ocr'],
        type=str,
        default='captcha',
    )
    # Required: without it the image readers below would be handed None and
    # fail far from the real cause.
    parser.add_argument(
        "--file", help="Path to the CAPTCHA image file", required=True
    )
    # The help texts below state the defaults that are actually configured.
    parser.add_argument(
        "--prefix",
        help="Checkpoint prefix [Default './models/model']",
        default='./models/model',
    )
    parser.add_argument(
        "--epoch", help="Checkpoint epoch [Default 20]", type=int, default=20
    )
    # Help text (Chinese): "stores the mapping from each character to its id".
    parser.add_argument(
        '--charset_file', type=str, help='存储了每个字对应哪个id的关系.'
    )
    args = parser.parse_args()

    if args.dataset == 'cn_ocr':
        hp = Hyperparams()
        img = read_ocr_img(args.file, hp)
    else:
        hp = Hyperparams2()
        img = read_captcha_img(args.file, hp)

    # Single-image batch.
    sample = SimpleBatch(data_names=['data'], data=[mx.nd.array([img])])

    network = crnn_lstm(hp)
    mod = load_module(
        args.prefix,
        args.epoch,
        sample.data_names,
        sample.provide_data,
        network=network,
    )

    mod.forward(sample)
    prob = mod.get_outputs()[0].asnumpy()

    # Only the label sequence is used; the second return value is ignored.
    prediction, _ = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())

    if args.charset_file:
        alphabet, _ = read_charset(args.charset_file)
        res = [alphabet[p] for p in prediction]
        print("Predicted Chars:", res)
    else:
        # Predictions are 1 to 10 for digits 0 to 9 respectively
        # (prediction 0 means no-digit)
        prediction = [p - 1 for p in prediction]
        print("Digits:", prediction)