def main_test(args):
    """Evaluate a trained SAR model on an image-list dataset.

    Builds a batch-size-1 inference graph (variable-width input), restores the
    EMA-shadow weights from ``args.checkpoints``, runs recognition over every
    image returned by ``get_data(args)``, optionally dumps per-character
    attention overlays to ``args.vis_dir``, and prints the final accuracy.

    Args:
        args: parsed CLI namespace; uses voc_type, height, max_len, encoder/
            decoder dims and layer counts, checkpoints, vis_dir.
    """
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    # Width is None: this graph accepts variable-width images (batch size 1).
    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, None, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    # NOTE(review): this call unpacks 3 values while the sibling test
    # functions unpack 4 (`..., pred, _`) — confirm SARModel.__call__'s arity.
    model_infer, attention_weights, pred = sar_model(input_images, input_labels, input_images_width,
                                                     batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    # Restore the exponential-moving-average shadow variables, matching training.
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        images_path, labels = get_data(args)
        predicts = []
        for img_path, label in zip(images_path, labels):
            try:
                img = cv2.imread(img_path)
            except Exception as e:
                print("{} error: {}".format(img_path, e))
                continue
            img, la, width = data_preprocess(img, label, char2id, args)
            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img], input_labels: [la], input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            # '$' marks the end-of-sequence position in the attention dumps below.
            pred_value_str += '$'
            if args.vis_dir is not None and args.vis_dir != "":
                os.makedirs(args.vis_dir, exist_ok=True)
                assert len(img.shape) == 3
                # Collapse to T * H * W * 1 attention maps, one per decode step.
                att_maps = attention_weights_value.reshape(
                    [-1, attention_weights_value.shape[2], attention_weights_value.shape[3], 1])
                for i, att_map in enumerate(att_maps):
                    if i >= len(pred_value_str):
                        break
                    att_map = cv2.resize(att_map, (img.shape[1], img.shape[0]))
                    # Paint the attention into the red (last BGR) channel only.
                    _att_map = np.zeros(dtype=np.uint8, shape=[img.shape[0], img.shape[1], 3])
                    _att_map[:, :, -1] = (att_map * 255).astype(np.uint8)
                    show_attention = cv2.addWeighted(img, 0.5, _att_map, 2, 0)
                    cv2.imwrite(os.path.join(args.vis_dir,
                                             os.path.basename(img_path).split('.')[0]
                                             + "_" + str(i) + "_" + pred_value_str[i] + ".jpg"),
                                show_attention)
        acc_rate = accuracy(predicts, labels)
        print("Done, Accuracy: {}".format(acc_rate))
def main_test_with_lexicon(args):
    """Evaluate a trained SAR model with lexicon-constrained scoring.

    Same restore/inference flow as ``main_test`` but with fixed-width input,
    a lexicon per image from ``get_data_lexicon``, line-style attention
    visualisation (mis-recognised samples additionally go to
    ``vis_dir/errors``), and accuracy via ``calc_metrics_lexicon``.

    Args:
        args: parsed CLI namespace; uses voc_type, height, width, max_len,
            encoder/decoder dims and layer counts, checkpoints, vis_dir.
    """
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    model_infer, attention_weights, pred, _ = sar_model(input_images, input_labels, input_images_width,
                                                        batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    # Restore EMA shadow weights to match how checkpoints were written by training.
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))
        images_path, labels, lexicons = get_data_lexicon(args)
        predicts = []
        for i in tqdm.tqdm(range(len(images_path))):
            img_path = images_path[i]
            label = labels[i]
            try:
                img = cv2.imread(img_path)
            except Exception as e:
                print("{} error: {}".format(img_path, e))
                continue
            img, la, width = data_preprocess(img, label, char2id, args)
            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img], input_labels: [la], input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            predicts.append(pred_value_str)
            if args.vis_dir is not None and args.vis_dir != "":
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                _ = line_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, img_path)
                # Keep a separate copy of wrongly recognised samples (case-insensitive compare).
                if pred_value_str.lower() != label.lower():
                    _ = line_visualize(img, attention_weights_value, pred_value_str,
                                       os.path.join(args.vis_dir, "errors"), img_path)
        acc_rate = calc_metrics_lexicon(predicts, labels, lexicons)
        print("Done, Accuracy: {}".format(acc_rate))
def idx2label(inputs, id2char=None, char2id=None):
    """Decode a batch of character-id sequences into strings.

    Decoding stops at the first 'EOS' id in each row; 'UNK' ids are skipped.
    When no vocabulary mappings are supplied, the 'ALLCASES_SYMBOLS'
    vocabulary is loaded as the default.

    Args:
        inputs: 2-D numpy array of character ids, shape (batch, seq_len).
        id2char: optional mapping id -> character.
        char2id: optional mapping character -> id (must contain 'EOS', 'UNK').

    Returns:
        List of decoded strings, one per row. Non-ndarray input is returned
        unchanged after a warning print (preserves original behavior).
    """
    if id2char is None:
        voc, char2id, id2char = get_vocabulary(voc_type="ALLCASES_SYMBOLS")

    def end_cut(ids):
        # Collect characters up to (excluding) EOS; drop UNK placeholders.
        chars = []
        for char_id in ids:  # renamed from `id` to avoid shadowing the builtin
            if char_id == char2id['EOS']:
                break
            if char_id != char2id['UNK']:
                chars.append(id2char[char_id])
        return chars

    if isinstance(inputs, np.ndarray):
        assert len(inputs.shape) == 2, "input's rank should be 2"
        return [''.join(end_cut(row)) for row in inputs]
    print("input to idx2label should be numpy array")
    return inputs
def generator(image_dir, gt_path, input_height, input_width, batch_size, max_len,
              voc_type, keep_ratio=True, with_aug=True):
    """Infinite batch generator reading images from disk with a txt/json GT file.

    Shuffles the sample order each epoch, loads and resizes images (rotating
    near-vertical ones by 90 degrees), encodes transcriptions into padded id
    sequences terminated by EOS, and yields tuples of
    (images, labels, masks, lengths, label_strings, widths) as numpy arrays.

    Args:
        image_dir: directory containing the images.
        gt_path: ground-truth file; extension selects TextLoader / JSONLoader.
        input_height, input_width: target image size.
        batch_size: number of samples per yielded batch.
        max_len: label length including the EOS token.
        voc_type: one of 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS'.
        keep_ratio: pad to input_width after aspect-preserving resize.
        with_aug: randomly rotate ~1/5 of images by up to +/-15 degrees.

    Raises:
        ValueError: if gt_path has an unsupported extension.
    """
    if gt_path.split(".")[-1] == 'txt':
        data_loader = TextLoader(image_dir)
    elif gt_path.split(".")[-1] == 'json':
        data_loader = JSONLoader(image_dir)
    else:
        # Fail fast: the original fell through with data_loader unbound (NameError).
        raise ValueError("Unsupported gt file format: {}".format(gt_path))
    images_path, transcriptions = data_loader.parse_gt(gt_path)
    print("There are {} images in {}".format(len(images_path), image_dir))
    assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']
    index = np.arange(0, len(images_path))
    voc, char2id, id2char = get_vocabulary(voc_type)
    is_lowercase = (voc_type == 'LOWERCASE')
    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []
    while True:
        np.random.shuffle(index)
        for i in index:
            try:
                img = cv2.imread(images_path[i])
                word = transcriptions[i]
                if is_lowercase:
                    word = word.lower()
                if img is None:
                    print("corrupt {}".format(images_path[i]))
                    continue
                H, W, C = img.shape
                # Rotate near-vertical images so text runs horizontally.
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H
                img_resize = np.zeros((input_height, input_width, C), dtype=np.uint8)
                # Data augmentation: ~1/5 chance of a random rotation in [-15, 15] deg.
                if with_aug:
                    ratn = random.randint(0, 4)
                    if ratn == 0:
                        rand_reg = random.random() * 30 - 15
                        img = rotate_img(img, rand_reg)
                if keep_ratio:
                    # NOTE(review): this width formula scales by H/input_height rather
                    # than the usual W*input_height/H — confirm intended aspect handling.
                    new_width = int((1.0 * H / input_height) * input_width)
                    new_width = new_width if new_width < input_width else input_width
                    new_width = new_width if new_width >= input_height else input_height
                    new_height = input_height
                    img = cv2.resize(img, (new_width, new_height))
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = input_width
                    img_resize = cv2.resize(img, (input_width, input_height))
                # Encode the transcription: PAD-filled, EOS-terminated id vector.
                # np.int was removed in NumPy 1.24; builtin int is the same dtype.
                label = np.full((max_len), char2id['PAD'], dtype=int)
                label_mask = np.full((max_len), 0, dtype=int)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])
                # Truncate so the EOS token always fits inside max_len.
                if len(label_list) > (max_len - 1):
                    label_list = label_list[:(max_len - 1)]
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)
                if label.shape[0] <= 0:
                    continue
                label_len = len(label_list)
                label_mask[:label_len] = 1
                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)
                assert len(batch_images) == len(batch_labels) == len(batch_lengths)
                if len(batch_images) == batch_size:
                    yield np.array(batch_images), np.array(batch_labels), np.array(batch_masks), \
                        np.array(batch_lengths), batch_labels_str, np.array(batch_images_width)
                    batch_images = []
                    batch_images_width = []
                    batch_labels = []
                    batch_masks = []
                    batch_lengths = []
                    batch_labels_str = []
            except Exception as e:
                # Best-effort loader: log and skip bad samples, keep the stream alive.
                print(e)
                print(images_path[i])
                continue
def main_test_lmdb(args):
    """Evaluate a trained SAR model on an LMDB dataset.

    Reads samples by 1-based 'image-%09d' / 'label-%09d' keys from the LMDB
    at ``args.test_data_dir``, runs batch-size-1 inference with EMA-restored
    weights, optionally writes heatmap attention visualisations (errors also
    copied to ``vis_dir/errors``), and prints the final accuracy.

    Args:
        args: parsed CLI namespace; uses voc_type, height, width, max_len,
            encoder/decoder dims and layer counts, checkpoints, test_data_dir,
            vis_dir.
    """
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    model_infer, attention_weights, pred, _ = sar_model(input_images, input_labels, input_images_width,
                                                        batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    # Restore EMA shadow weights, matching how training saved them.
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))
        env = lmdb.open(args.test_data_dir, readonly=True)
        txn = env.begin()
        num_samples = int(txn.get(b"num-samples").decode())
        predicts = []
        labels = []
        # LMDB keys are 1-based: image-000000001 ... image-<num_samples>.
        for i in tqdm.tqdm(range(1, num_samples + 1)):
            image_key = b'image-%09d' % i
            label_key = b'label-%09d' % i
            imgbuf = txn.get(image_key)
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            img_pil = Image.open(buf).convert('RGB')
            img = np.array(img_pil)
            label = txn.get(label_key).decode()
            labels.append(label)
            img, la, width = data_preprocess(img, label, char2id, args)
            pred_value, attention_weights_value = sess.run(
                [pred, attention_weights],
                feed_dict={input_images: [img], input_labels: [la], input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            predicts.append(pred_value_str)
            # PIL delivered RGB; OpenCV writers expect BGR.
            img_vis = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            if args.vis_dir is not None and args.vis_dir != "":
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str,
                                      args.vis_dir, "{}.jpg".format(i))
                # Keep a separate copy of wrongly recognised samples (case-insensitive).
                if pred_value_str.lower() != label.lower():
                    _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str,
                                          os.path.join(args.vis_dir, "errors"), "{}.jpg".format(i))
        acc_rate = calc_metrics(predicts, labels)
        print("Done, Accuracy: {}".format(acc_rate))
def main_train(args):
    """Build the train/val graph, restore a checkpoint, and export a frozen .pb.

    Despite its name this variant does not run a training loop: it constructs
    the full training graph (so checkpoint variables resolve), restores from
    ``args.checkpoints``, then freezes the graph to constants and writes
    ``<checkpoint>.pb`` next to the checkpoint.

    NOTE(review): a second ``def main_train(args)`` appears later in this file
    and shadows this definition at import time — presumably this copy is kept
    only for the freeze workflow; confirm before removing.
    """
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    # Build graph
    input_train_images = tf.placeholder(dtype=tf.float32, shape=[None, args.height, args.width, 3], name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32, shape=[None], name="input_train_width")
    input_train_labels = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_train_labels")
    input_train_labels_mask = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_train_labels_mask")
    input_val_images = tf.placeholder(dtype=tf.float32, shape=[None, args.height, args.width, 3], name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32, shape=[None], name="input_val_width")
    input_val_labels = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_val_labels")
    input_val_labels_mask = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_val_labels_mask")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True)
    # Validation graph shares weights with the training graph (reuse=True below).
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred = sar_model(input_train_images, input_train_labels, input_train_images_width, reuse=False)
    train_loss = sar_model.loss(train_model_infer, input_train_labels, input_train_labels_mask)
    val_model_infer, val_attention_weights, val_pred = sar_model_val(input_val_images, input_val_labels, input_val_images_width, reuse=True)
    val_loss = sar_model_val.loss(val_model_infer, input_val_labels, input_val_labels_mask)
    global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=args.lr, global_step=global_step,
                                               decay_steps=args.decay_iter, decay_rate=args.weight_decay,
                                               staircase=True)
    batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Optimizer ops are created only so the checkpoint's optimizer variables exist.
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_to_restore = variable_averages.variables_to_restore()
    # Restore the EMA shadow values so the frozen graph carries averaged weights.
    saver = tf.train.Saver(variables_to_restore)
    summary_writer = tf.summary.FileWriter(args.checkpoints)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        # NOTE(review): if args.checkpoints == '' the freeze block below uses
        # an unbound ckpt_state — this path presumably always runs with a
        # checkpoint directory; confirm.
        if args.checkpoints != '':
            print('Restore model from {:s}'.format(args.checkpoints))
            ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
            model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            ['sar_1/ArgMax']  # The output node names are used to select the useful nodes
        )
        frozen_model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path)) + ".pb"
        with tf.gfile.GFile(frozen_model_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("Frozen model saved at " + frozen_model_path)
def main_train(args):
    """Full SAR training loop with attention supervision.

    Builds fixed-batch train/val graphs, selects the attention-loss variant
    from ``args.att_loss_type`` (kldiv / l1 / l2 / ce / gausskldiv), trains
    with Adam (+optional gradient clipping, EMA, BN updates), periodically
    logs, writes summaries, evaluates on the LMDB validation set, and saves
    regular and best-accuracy checkpoints.

    Args:
        args: parsed CLI namespace (batch sizes, dims, loss type/weight, lr
            schedule, iteration counts, checkpoint/pretrain paths, etc.).
    """
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    tf.set_random_seed(1)
    # Build graph
    input_train_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.height, args.width, 3],
        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels")
    # Per-step target attention maps at the decoder's 6x40 resolution.
    input_train_gauss_labels = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 6, 40],
        name="input_train_gauss_labels")  # better way wanted!!!
    input_train_gauss_tags = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_gauss_tags")
    input_train_gauss_params = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 4],
        name="input_train_gauss_params")
    input_train_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels_mask")
    input_train_char_sizes = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 2],
        name="input_train_char_sizes")
    input_train_gauss_mask = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 6, 40],
        name="input_train_gauss_mask")
    input_val_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.val_batch_size, args.height, args.width, 3],
        name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels")
    input_val_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels_mask")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True,
                         att_loss_type=args.att_loss_type,
                         att_loss_weight=args.att_loss_weight)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred, train_attention_params = sar_model(
        input_train_images, input_train_labels, input_train_images_width,
        batch_size=args.train_batch_size, reuse=False)
    # Loss signature varies per attention-supervision type.
    if args.att_loss_type == "kldiv":
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask, input_train_gauss_tags)
    elif args.att_loss_type == "l1" or args.att_loss_type == "l2":
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask, input_train_gauss_tags)
    elif args.att_loss_type == 'ce':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask, input_train_gauss_tags,
            input_train_gauss_mask)
    elif args.att_loss_type == 'gausskldiv':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_params, input_train_labels,
            input_train_gauss_params, input_train_labels_mask, input_train_gauss_tags)
    else:
        # Fixed: originally referenced non-existent args.att_loss_dtype, which
        # raised AttributeError instead of printing the message.
        print("Unimplemented loss type {}".format(args.att_loss_type))
        exit(-1)
    val_model_infer, val_attention_weights, val_pred, _ = sar_model_val(
        input_val_images, input_val_labels, input_val_images_width,
        batch_size=args.val_batch_size, reuse=True)
    train_data_list = get_data(args.train_data_dir, args.voc_type, args.max_len,
                               args.height, args.width, args.train_batch_size,
                               args.workers, args.keep_ratio, with_aug=args.aug)
    val_data_gen = evaluator_data.Evaluator(lmdb_data_dir=args.test_data_dir,
                                            batch_size=args.val_batch_size,
                                            height=args.height,
                                            width=args.width,
                                            max_len=args.max_len,
                                            keep_ratio=args.keep_ratio,
                                            voc_type=args.voc_type)
    val_data_gen.reset()
    global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
    learning_rate = tf.train.piecewise_constant(global_step, args.decay_bound, args.lr_stage)
    # Fixed: tf.group expects ops as *args, not a list (matches the sibling
    # main_train variant, which already unpacks the collection).
    batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='train_recog_loss', tensor=train_recog_loss)
    tf.summary.scalar(name='train_att_loss', tensor=train_att_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    merge_summary_op = tf.summary.merge_all()
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    best_model_save_path = os.path.join(args.checkpoints, 'best_model', model_name)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # Every train step also runs the EMA update and BN moving-stat updates.
    with tf.control_dependencies([variables_averages_op, batch_norm_updates_op]):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        grads = optimizer.compute_gradients(train_loss)
        if args.grad_clip > 0:
            print("With Gradients clipped!")
            for idx, (grad, var) in enumerate(grads):
                grads[idx] = (tf.clip_by_norm(grad, args.grad_clip), var)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_file = open(os.path.join(args.checkpoints, args.checkpoints + ".log"), "w")
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume == True and args.pretrained != '':
            # Resume: keep the restored global_step so iteration count continues.
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained, os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        elif args.resume == False and args.pretrained != '':
            # Fine-tune: load weights but restart the step counter at 0.
            print('Restore pretrained model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained, os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            sess.run(tf.assign(global_step, 0))
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        # Evaluate the model first
        val_pred_value_all = []
        val_labels = []
        for eval_iter in range(val_data_gen.num_samples // args.val_batch_size):
            val_data = val_data_gen.get_batch()
            if val_data is None:
                break
            print("Evaluation: [{} / {}]".format(
                eval_iter, (val_data_gen.num_samples // args.val_batch_size)))
            val_pred_value = sess.run(val_pred,
                                      feed_dict={
                                          input_val_images: val_data[0],
                                          input_val_labels: val_data[1],
                                          input_val_images_width: val_data[5],
                                          input_val_labels_mask: val_data[2]
                                      })
            val_pred_value_all.extend(val_pred_value)
            val_labels.extend(val_data[4])
        val_data_gen.reset()
        val_metrics_result = calc_metrics(idx2label(np.array(val_pred_value_all)),
                                          val_labels, metrics_type="accuracy")
        print("Evaluation Before training: Test accuracy {:3f}".format(val_metrics_result))
        val_best_acc = val_metrics_result
        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_recog_loss_value, train_att_loss_value, train_pred_value = sess.run(
                [train_op, train_loss, train_recog_loss, train_att_loss, train_pred],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_gauss_labels: train_data[2],
                    input_train_gauss_params: train_data[7],
                    input_train_labels_mask: train_data[3],
                    input_train_images_width: train_data[5],
                    input_train_gauss_tags: train_data[6],
                    input_train_char_sizes: train_data[8],
                    input_train_gauss_mask: train_data[9]
                })
            if start_iter % args.log_iter == 0:
                print("Iter {} train loss= {:3f} (recog loss= {:3f} att loss= {:3f})"
                      .format(start_iter, train_loss_value, train_recog_loss_value,
                              train_att_loss_value))
                # Fixed: added the trailing newline so log lines don't run together
                # (the eval-time write below already had one).
                log_file.write(
                    "Iter {} train loss= {:3f} (recog loss= {:3f} att loss= {:3f})\n"
                    .format(start_iter, train_loss_value, train_recog_loss_value,
                            train_att_loss_value))
            if start_iter % args.summary_iter == 0:
                merge_summary_value = sess.run(merge_summary_op,
                                               feed_dict={
                                                   input_train_images: train_data[0],
                                                   input_train_labels: train_data[1],
                                                   input_train_gauss_labels: train_data[2],
                                                   input_train_gauss_params: train_data[7],
                                                   input_train_labels_mask: train_data[3],
                                                   input_train_images_width: train_data[5],
                                                   input_train_gauss_tags: train_data[6],
                                                   input_train_char_sizes: train_data[8],
                                                   input_train_gauss_mask: train_data[9]
                                               })
                summary_writer.add_summary(summary=merge_summary_value, global_step=start_iter)
            if start_iter % args.eval_iter == 0:
                val_pred_value_all = []
                val_labels = []
                for eval_iter in range(val_data_gen.num_samples // args.val_batch_size):
                    val_data = val_data_gen.get_batch()
                    if val_data is None:
                        break
                    print("Evaluation: [{} / {}]".format(
                        eval_iter, (val_data_gen.num_samples // args.val_batch_size)))
                    val_pred_value = sess.run(val_pred,
                                              feed_dict={
                                                  input_val_images: val_data[0],
                                                  input_val_labels: val_data[1],
                                                  input_val_labels_mask: val_data[2],
                                                  input_val_images_width: val_data[5]
                                              })
                    val_pred_value_all.extend(val_pred_value)
                    val_labels.extend(val_data[4])
                val_data_gen.reset()
                train_metrics_result = calc_metrics(idx2label(train_pred_value),
                                                    train_data[4],
                                                    metrics_type="accuracy")
                val_metrics_result = calc_metrics(idx2label(np.array(val_pred_value_all)),
                                                  val_labels, metrics_type="accuracy")
                print("Evaluation Iter {} train accuracy: {:3f} test accuracy {:3f}"
                      .format(start_iter, train_metrics_result, val_metrics_result))
                log_file.write(
                    "Evaluation Iter {} train accuracy: {:3f} test accuracy {:3f}\n"
                    .format(start_iter, train_metrics_result, val_metrics_result))
                if val_metrics_result >= val_best_acc:
                    print("Better results! Save checkpoitns to {}".format(best_model_save_path))
                    val_best_acc = val_metrics_result
                    best_saver.save(sess, best_model_save_path, global_step=global_step)
            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
        log_file.close()
def generator(lmdb_dir, input_height, input_width, batch_size, max_len, voc_type,
              keep_ratio=True, with_aug=True):
    """Infinite batch generator reading samples from an LMDB dataset.

    LMDB keys are 1-based 'image-%09d' / 'label-%09d'. Each epoch the sample
    order is reshuffled; images are resized (near-vertical ones rotated 90
    degrees first), transcriptions are encoded into padded, EOS-terminated id
    vectors, and batches are yielded as
    (images, labels, masks, lengths, label_strings, widths) numpy arrays.

    Args:
        lmdb_dir: LMDB directory containing 'num-samples' plus keyed entries.
        input_height, input_width: target image size.
        batch_size: number of samples per yielded batch.
        max_len: label length including the EOS token.
        voc_type: vocabulary type passed to get_vocabulary.
        keep_ratio: pad to input_width after aspect-preserving resize.
        with_aug: randomly rotate ~1/5 of images by up to +/-15 degrees.
    """
    env = lmdb.open(lmdb_dir, max_readers=32, readonly=True)
    txn = env.begin()
    num_samples = int(txn.get(b"num-samples").decode())
    print("There are {} images in {}".format(num_samples, lmdb_dir))
    index = np.arange(0, num_samples)  # TODO check index is reliable
    voc, char2id, id2char = get_vocabulary(voc_type)
    is_lowercase = (voc_type == 'LOWERCASE')
    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []
    while True:
        np.random.shuffle(index)
        for i in index:
            i += 1  # LMDB keys are 1-based
            try:
                image_key = b'image-%09d' % i
                label_key = b'label-%09d' % i
                imgbuf = txn.get(image_key)
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img_pil = Image.open(buf).convert('RGB')
                img = np.array(img_pil)
                word = txn.get(label_key).decode()
                if is_lowercase:
                    word = word.lower()
                H, W, C = img.shape
                # Rotate near-vertical images so text runs horizontally.
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H
                img_resize = np.zeros((input_height, input_width, C), dtype=np.uint8)
                # Data augmentation: ~1/5 chance of a random rotation in [-15, 15] deg.
                if with_aug:
                    ratn = random.randint(0, 4)
                    if ratn == 0:
                        rand_reg = random.random() * 30 - 15
                        img = rotate_img(img, rand_reg)
                if keep_ratio:
                    # NOTE(review): width formula scales by H/input_height rather than
                    # the usual W*input_height/H — confirm intended aspect handling.
                    new_width = int((1.0 * H / input_height) * input_width)
                    new_width = new_width if new_width < input_width else input_width
                    new_width = new_width if new_width >= input_height else input_height
                    new_height = input_height
                    img = cv2.resize(img, (new_width, new_height))
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = input_width
                    img_resize = cv2.resize(img, (input_width, input_height))
                # Encode the transcription: PAD-filled, EOS-terminated id vector.
                # np.int was removed in NumPy 1.24; builtin int is the same dtype.
                label = np.full((max_len), char2id['PAD'], dtype=int)
                label_mask = np.full((max_len), 0, dtype=int)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])
                # Truncate so the EOS token always fits inside max_len.
                if len(label_list) > (max_len - 1):
                    label_list = label_list[:(max_len - 1)]
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)
                if label.shape[0] <= 0:
                    continue
                label_len = len(label_list)
                label_mask[:label_len] = 1
                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)
                assert len(batch_images) == len(batch_labels) == len(batch_lengths)
                if len(batch_images) == batch_size:
                    yield np.array(batch_images), np.array(batch_labels), np.array(batch_masks), \
                        np.array(batch_lengths), batch_labels_str, np.array(batch_images_width)
                    batch_images = []
                    batch_images_width = []
                    batch_labels = []
                    batch_masks = []
                    batch_lengths = []
                    batch_labels_str = []
            except Exception as e:
                # Best-effort loader: log and skip bad samples, keep the stream alive.
                print(e)
                print("Error in %d" % i)
                continue
def get_batch(self):
    """Return the next batch from the LMDB dataset, or None when exhausted.

    Returns:
        A 6-tuple (images, labels, label_masks, label_lengths,
        label_strings, image_widths) as numpy arrays (strings stay a list),
        or None after self.reset() when a full batch no longer fits before
        self.num_samples.

    Corrupt entries are logged and skipped (best effort).
    """
    voc, char2id, id2char = get_vocabulary(self.voc_type)
    is_lowercase = (self.voc_type == 'LOWERCASE')
    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []
    # Not enough remaining samples for a full batch: rewind for next epoch.
    if (self.index + self.batch_size - 1) > self.num_samples:
        self.reset()
        return None
    # for i in range(self.index, self.index + self.batch_size - 1):
    while len(batch_images) < self.batch_size:
        # Bug fix: when corrupt entries are skipped the index can run past
        # num_samples *inside* this loop; txn.get() would then return None
        # forever and the loop would never terminate. Mirror the up-front
        # exhaustion handling instead.
        if self.index > self.num_samples:
            self.reset()
            return None
        try:
            image_key = b'image-%09d' % self.index
            label_key = b'label-%09d' % self.index
            self.index += 1
            imgbuf = self.txn.get(image_key)
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            img_pil = Image.open(buf).convert('RGB')
            img = np.array(img_pil)
            word = self.txn.get(label_key).decode()
            if is_lowercase:
                word = word.lower()
            H, W, C = img.shape
            # Rotate the vertical images so the text reads horizontally.
            if H > 1.1 * W:
                img = np.rot90(img)
                H, W = W, H
            # Blank canvas; when keep_ratio pads, the right part stays 0.
            img_resize = np.zeros((self.height, self.width, C),
                                  dtype=np.uint8)
            if self.keep_ratio:
                # NOTE(review): width derived from H/self.height, not W/H --
                # matches the generator implementations in this file.
                new_width = int((1.0 * H / self.height) * self.width)
                new_width = new_width if new_width < self.width else self.width
                new_width = new_width if new_width >= self.height else self.height
                new_height = self.height
                img = cv2.resize(img, (new_width, new_height))
                img_resize[:new_height, :new_width, :] = img.copy()
            else:
                new_width = self.width
                img_resize = cv2.resize(img, (self.width, self.height))
            # np.int was removed in NumPy 1.24; the builtin int is the exact
            # equivalent of the old alias.
            label = np.full((self.max_len), char2id['PAD'], dtype=int)
            label_mask = np.full((self.max_len), 0, dtype=int)
            label_list = []
            for char in word:
                if char in char2id:
                    label_list.append(char2id[char])
                else:
                    label_list.append(char2id['UNK'])
            # Truncate so there is always room for the EOS token.
            if len(label_list) > (self.max_len - 1):
                label_list = label_list[:(self.max_len - 1)]
            label_list = label_list + [char2id['EOS']]
            label[:len(label_list)] = np.array(label_list)
            if label.shape[0] <= 0:
                continue
            label_len = len(label_list)
            label_mask[:label_len] = 1
            batch_images.append(img_resize)
            batch_images_width.append(new_width)
            batch_labels.append(label)
            batch_masks.append(label_mask)
            batch_lengths.append(label_len)
            batch_labels_str.append(word)
            assert len(batch_images) == len(batch_labels) == len(
                batch_lengths)
        except Exception as e:
            # Best-effort loader: log and skip corrupt entries.
            print(e)
            print("Error in %d" % self.index)
            continue
    return np.array(batch_images), np.array(batch_labels), np.array(
        batch_masks), np.array(batch_lengths), batch_labels_str, np.array(
            batch_images_width)
def main_train(args):
    """Build the SAR train/val graphs and run the optimization loop.

    Creates two SARModel instances over fixed-size placeholders -- a
    training one (is_training=True) and a weight-sharing validation one
    (is_training=False, reuse=True) -- trains with Adam under an
    exponentially decayed learning rate, maintains an exponential moving
    average of the trainable variables, and periodically logs the loss,
    writes summaries, prints accuracy and saves checkpoints into
    args.checkpoints.
    """
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    # Build graph
    input_train_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.height, args.width, 3],
        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels")
    input_train_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels_mask")
    input_val_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.val_batch_size, args.height, args.width, 3],
        name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels")
    input_val_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels_mask")
    # Training-mode model and a separate eval-mode model that reuses its
    # variables (reuse=True below).
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred = sar_model(
        input_train_images,
        input_train_labels,
        input_train_images_width,
        batch_size=args.train_batch_size,
        reuse=False)
    train_loss = sar_model.loss(train_model_infer, input_train_labels,
                                input_train_labels_mask)
    val_model_infer, val_attention_weights, val_pred = sar_model_val(
        input_val_images,
        input_val_labels,
        input_val_images_width,
        batch_size=args.val_batch_size,
        reuse=True)
    val_loss = sar_model_val.loss(val_model_infer, input_val_labels,
                                  input_val_labels_mask)
    train_data_list = get_data(args.train_data_dir, args.train_data_gt,
                               args.voc_type, args.max_len, args.num_train,
                               args.height, args.width, args.train_batch_size,
                               args.workers, args.keep_ratio,
                               with_aug=args.aug)
    val_data_list = get_data(args.test_data_dir, args.test_data_gt,
                             args.voc_type, args.max_len, args.num_train,
                             args.height, args.width, args.val_batch_size,
                             args.workers, args.keep_ratio, with_aug=False)
    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)
    # NOTE(review): decay_rate is fed from args.weight_decay -- the flag name
    # suggests L2 weight decay, but here it acts as the LR decay factor.
    learning_rate = tf.train.exponential_decay(learning_rate=args.lr,
                                               global_step=global_step,
                                               decay_steps=args.decay_iter,
                                               decay_rate=args.weight_decay,
                                               staircase=True)
    # Batch-norm moving statistics must be updated alongside the train step.
    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads,
                                                  global_step=global_step)
    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='val_loss', tensor=val_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    merge_summary_op = tf.summary.merge_all()
    # Checkpoint name carries the training start timestamp.
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    # Exponential moving average of the weights (decay 0.997); the test-time
    # code restores the shadow variables.
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # One no-op whose control deps bundle EMA update, gradient application
    # and BN statistics updates into a single train step.
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        # Resume from a checkpoint if requested, otherwise init fresh.
        if args.resume == True and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        while start_iter < args.iters:
            start_iter += 1
            # train_data layout (from the data pipeline): [0] images,
            # [1] labels, [2] label masks, [3] label strings, [4] widths.
            train_data = get_batch_data(train_data_list,
                                        args.train_batch_size)
            _, train_loss_value, train_pred_value = sess.run(
                [train_op, train_loss, train_pred],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_labels_mask: train_data[2],
                    input_train_images_width: train_data[4]
                })
            if start_iter % args.log_iter == 0:
                print("Iter {} train loss= {:3f}".format(
                    start_iter, train_loss_value))
            if start_iter % args.summary_iter == 0:
                val_data = get_batch_data(val_data_list, args.val_batch_size)
                merge_summary_value, val_pred_value, val_loss_value = sess.run(
                    [merge_summary_op, val_pred, val_loss],
                    feed_dict={
                        input_train_images: train_data[0],
                        input_train_labels: train_data[1],
                        input_train_labels_mask: train_data[2],
                        input_train_images_width: train_data[4],
                        input_val_images: val_data[0],
                        input_val_labels: val_data[1],
                        input_val_labels_mask: val_data[2],
                        input_val_images_width: val_data[4]
                    })
                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=start_iter)
                # Eval block lives inside the summary branch because it
                # needs val_data / val_pred_value computed above.
                if start_iter % args.eval_iter == 0:
                    print("#" * 80)
                    print("train prediction \t train labels ")
                    for result, gt in zip(idx2label(train_pred_value),
                                          train_data[3]):
                        print("{} \t {}".format(result, gt))
                    print("#" * 80)
                    print("test prediction \t test labels ")
                    # Only show the first 32 validation samples.
                    for result, gt in zip(
                            idx2label(val_pred_value)[:32], val_data[3][:32]):
                        print("{} \t {}".format(result, gt))
                    print("#" * 80)
                    train_metrics_result = calc_metrics(
                        idx2label(train_pred_value),
                        train_data[3],
                        metrics_type="accuracy")
                    val_metrics_result = calc_metrics(
                        idx2label(val_pred_value),
                        val_data[3],
                        metrics_type="accuracy")
                    print(
                        "Evaluation Iter {} test loss: {:3f} train accuracy: {:3f} test accuracy {:3f}"
                        .format(start_iter, val_loss_value,
                                train_metrics_result, val_metrics_result))
            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
def generator(lmdb_dir, input_height, input_width, batch_size, max_len,
              voc_type, keep_ratio=True, with_aug=True):
    """Yield batches with per-character gaussian attention supervision.

    Variant of the plain generator: besides the image/label tensors, every
    sample carries, for each decoding step, a gaussian target map built
    from the character bounding boxes stored under 'char-%09d', the
    estimated distribution parameters, the character box size, a tag
    marking steps with valid supervision and a spatial loss mask.

    Yields an 11-tuple:
        images, labels, gauss_labels (reduced to the 6x40 feature map),
        label_masks, label_lengths, label_strings, image_widths,
        gauss_tags, distrib_params, char_sizes, gauss_masks

    Samples whose box count does not match the label length, and corrupt
    entries, are logged and skipped.
    """
    env = lmdb.open(lmdb_dir, max_readers=32, readonly=True)
    txn = env.begin()
    if txn.get(b"num-samples") is None:
        # SynthText800K is still generating
        num_samples = 3000000
    else:
        num_samples = int(txn.get(b"num-samples").decode())
    print("There are {} images in {}".format(num_samples, lmdb_dir))
    index = np.arange(0, num_samples)  # TODO check index is reliable
    voc, char2id, id2char = get_vocabulary(voc_type)
    is_lowercase = (voc_type == 'LOWERCASE')
    if with_aug:
        # Hoisted out of the sample loop: the augmenters are loop-invariant,
        # so building them once avoids per-image construction cost.
        blur_aug = imgaug.augmenters.blur.GaussianBlur(sigma=(0, 1.0))
        contrast_aug = imgaug.augmenters.contrast.LinearContrast((0.75, 1.0))
        affine_aug = imgaug.augmenters.geometric.PiecewiseAffine(
            scale=(0.01, 0.02), mode='constant', cval=0)
    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_gausses = []
    batch_params = []
    batch_gauss_tags = []
    batch_gauss_masks = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []
    batch_char_size = []
    batch_char_bbs = []
    while True:
        np.random.shuffle(index)
        for i in index:
            i += 1  # LMDB keys are 1-based
            try:
                image_key = b'image-%09d' % i
                label_key = b'label-%09d' % i
                char_key = b'char-%09d' % i
                imgbuf = txn.get(image_key)
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img_pil = Image.open(buf).convert('RGB')
                img = np.array(img_pil)
                word = txn.get(label_key).decode()
                # Datasets without char annotations get all-zero boxes and
                # produce no gaussian supervision (gauss_tags stay 0).
                wo_char = (txn.get(char_key) is None)
                if wo_char:
                    charBBs = np.zeros(dtype=np.float32,
                                       shape=[len(word), 4, 2])
                else:
                    charBBs = np.array([
                        math.ceil(float(c))
                        for c in txn.get(char_key).decode().split()
                    ]).reshape([-1, 4, 2]).astype(np.float32)
                num_char_bb = charBBs.shape[0]
                if is_lowercase:
                    word = word.lower()
                H, W, C = img.shape
                # Clamp boxes into the image.
                charBBs[:, :, 0] = np.clip(charBBs[:, :, 0], 0, W)
                charBBs[:, :, 1] = np.clip(charBBs[:, :, 1], 0, H)
                # Rotate the vertical images (and their boxes).
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H
                    charBBs = charBBs[:, :, ::-1]
                    charBBs[:, :, 1] = H - charBBs[:, :, 1]
                # Blank canvas; when keep_ratio pads, the right part stays 0.
                img_resize = np.zeros((input_height, input_width, C),
                                      dtype=np.uint8)
                # Data augmentation: each effect fires with probability 1/2.
                if with_aug:
                    if random.randint(0, 1) == 0:
                        img = blur_aug.augment_image(img)
                    if random.randint(0, 1) == 0:
                        img = contrast_aug.augment_image(img)
                    if random.randint(0, 1) == 0:
                        img = affine_aug.augment_image(img)
                    if random.randint(0, 1) == 0:
                        rand_reg = random.randint(-15, 15)
                        img, charBBs, (W, H) = rotate_img(img, rand_reg,
                                                          BBs=charBBs)
                if keep_ratio:
                    # NOTE(review): width derived from H/input_height, not
                    # W/H -- matches the other loaders in this file.
                    new_width = int((1.0 * H / input_height) * input_width)
                    new_width = new_width if new_width < input_width else input_width
                    new_width = new_width if new_width >= input_height else input_height
                    new_height = input_height
                    img = cv2.resize(img, (new_width, new_height))
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = input_width
                    new_height = input_height
                    img_resize = cv2.resize(img, (input_width, input_height))
                # Scale the boxes into resized-image coordinates.
                ratio_w = float(new_width) / float(W)
                ratio_h = float(new_height) / float(H)
                charBBs[:, :, 0] = charBBs[:, :, 0] * ratio_w
                charBBs[:, :, 1] = charBBs[:, :, 1] * ratio_h
                # np.int was removed in NumPy 1.24; the builtin int is the
                # exact equivalent of the old alias.
                label = np.full((max_len), char2id['PAD'], dtype=int)
                label_mask = np.full((max_len), 0, dtype=int)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])
                # Truncate so there is always room for the EOS token; the
                # box count must follow the truncation.
                if len(label_list) > (max_len - 1):
                    label_list = label_list[:(max_len - 1)]
                    num_char_bb = max_len - 1
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)
                if label.shape[0] <= 0:
                    continue
                label_len = len(label_list)
                label_mask[:label_len] = 1
                # Every real character must have exactly one box.
                if (label_len - 1) != num_char_bb:
                    print(
                        "Unmatched between char bb and label length in index {}"
                        .format(i))
                    print(
                        "Information: label: {} label_length: {} num_char: {}".
                        format(word, label_len - 1, num_char_bb))
                    continue
                # Gaussian supervision targets at input resolution; they are
                # reduced to the 6x40 feature-map resolution further below.
                gauss_labels = np.zeros(
                    dtype=np.float32,
                    shape=[max_len, input_height, input_width])  # T * H * W
                gauss_mask = np.zeros(
                    dtype=np.float32,
                    shape=[max_len, input_height, input_width])  # T * H * W
                distrib_params = np.zeros(dtype=np.float32,
                                          shape=[max_len, 4])
                distrib_params[:, 2:] = 1.
                char_size = np.ones(dtype=np.float32, shape=[max_len, 2])
                gauss_tags = [0] * max_len
                charBBs = charBBs[:num_char_bb]
                if not wo_char:
                    # Bug fix: the loop variable used to be 'i', shadowing
                    # the outer sample index and corrupting the error log.
                    for ci, BB in enumerate(charBBs):  # BB is 4 * 2
                        # Min axis-aligned rectangle around the character
                        # quadrilateral.
                        min_rec, delta_x, delta_y = find_min_rectangle(BB)
                        if delta_x < 2 or delta_y < 2:
                            # Degenerate box: keep params, skip the map.
                            param = get_distrib_params(BB)
                            distrib_params[ci] = param
                            continue
                        gauss_distrib = get_gauss_distrib(
                            (delta_y, delta_x))  # delta_y * delta_x
                        param = estim_gauss_params(gauss_distrib, delta_x,
                                                   delta_y)
                        char_size[ci][0] = delta_x
                        char_size[ci][1] = delta_y
                        distrib_params[ci] = param
                        res_gauss = gauss_distrib
                        gauss_tags[ci] = 1.
                        if np.max(res_gauss) > 0.:
                            start_x, start_y = int(min_rec[0][0]), int(
                                min_rec[0][1])
                            end_x = start_x + delta_x
                            end_y = start_y + delta_y
                            end_x = end_x if end_x <= input_width else input_width
                            end_y = end_y if end_y <= input_height else input_height
                            gauss_labels[ci, start_y:end_y,
                                         start_x:end_x] = res_gauss
                            # Loss mask covers the box expanded by 30%,
                            # clamped to the image.
                            ex_start_x = math.floor(start_x - 0.3 * delta_x)
                            ex_end_x = math.ceil(end_x + 0.3 * delta_x)
                            ex_start_y = math.floor(start_y - 0.3 * delta_y)
                            ex_end_y = math.ceil(end_y + 0.3 * delta_y)
                            ex_start_x = ex_start_x if ex_start_x >= 0 else 0
                            ex_start_y = ex_start_y if ex_start_y >= 0 else 0
                            ex_end_x = ex_end_x if ex_end_x <= input_width else input_width
                            ex_end_y = ex_end_y if ex_end_y <= input_height else input_height
                            gauss_mask[ci, int(ex_start_y):int(ex_end_y),
                                       int(ex_start_x):int(ex_end_x)] = 1.
                # Normalize each step's map to sum to 1.
                gauss_labels = sum_norm(
                    gauss_labels.reshape([gauss_labels.shape[0],
                                          -1])).reshape(
                                              [-1, input_height, input_width])
                # Reduce to feature map size
                gauss_labels = roi_sum(gauss_labels, target_h=6, target_w=40)
                gauss_mask = roi_max(gauss_mask, target_h=6, target_w=40)
                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_gausses.append(gauss_labels)
                batch_params.append(distrib_params)
                batch_gauss_tags.append(gauss_tags)
                batch_gauss_masks.append(gauss_mask)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)
                batch_char_size.append(char_size)
                batch_char_bbs.append(charBBs)
                assert len(batch_images) == len(batch_labels) == len(
                    batch_lengths) == len(batch_gausses) == len(
                        batch_char_size)
                if len(batch_images) == batch_size:
                    yield (np.array(batch_images), np.array(batch_labels),
                           np.array(batch_gausses), np.array(batch_masks),
                           np.array(batch_lengths), batch_labels_str,
                           np.array(batch_images_width),
                           np.array(batch_gauss_tags), np.array(batch_params),
                           np.array(batch_char_size),
                           np.array(batch_gauss_masks))
                    batch_images = []
                    batch_images_width = []
                    batch_labels = []
                    batch_gausses = []
                    batch_params = []
                    batch_gauss_tags = []
                    batch_gauss_masks = []
                    batch_masks = []
                    batch_lengths = []
                    batch_labels_str = []
                    batch_char_size = []
                    # Bug fix: batch_char_bbs was never cleared after a
                    # yield, so it grew without bound over the epoch.
                    batch_char_bbs = []
            except Exception as e:
                # Best-effort loader: log and skip corrupt entries.
                print(e)
                print("Error in %d" % i)
                continue