Example #1
def main_test(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, None, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)

    model_infer, attention_weights, pred = sar_model(input_images, input_labels, input_images_width, batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)

        images_path, labels = get_data(args)
        predicts = []
        for img_path, label in zip(images_path, labels):
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None on failure instead of raising
                print("{} error: failed to read image".format(img_path))
                continue

            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run([pred, attention_weights], feed_dict={input_images: [img],
                                                                                                 input_labels: [la],
                                                                                                 input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            pred_value_str += '$'
            if args.vis_dir:
                os.makedirs(args.vis_dir, exist_ok=True)
                assert len(img.shape) == 3
                att_maps = attention_weights_value.reshape([-1, attention_weights_value.shape[2], attention_weights_value.shape[3], 1]) # T * H * W * 1
                for i, att_map in enumerate(att_maps):
                    if i >= len(pred_value_str):
                        break
                    att_map = cv2.resize(att_map, (img.shape[1], img.shape[0]))
                    _att_map = np.zeros(dtype=np.uint8, shape=[img.shape[0], img.shape[1], 3])
                    _att_map[:, :, -1] = (att_map * 255).astype(np.uint8)

                    show_attention = cv2.addWeighted(img, 0.5, _att_map, 2, 0)
                    cv2.imwrite(os.path.join(args.vis_dir, os.path.basename(img_path).split('.')[0] + "_" + str(i) + "_" + pred_value_str[i] + ".jpg"), show_attention)

    acc_rate = accuracy(predicts, labels)
    print("Done, Accuracy: {}".format(acc_rate))
Example #2
def main_test_with_lexicon(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)

    # encoder_state, feature_map, mask_map = sar_model.inference(input_images, input_images_width, 1, reuse=True)
    # model_infer, attention_weights, pred = sar_model.decode(encoder_state, feature_map, input_labels, mask_map, reuse=True, decode_type=args.decode_type)
    model_infer, attention_weights, pred, _ = sar_model(input_images, input_labels, input_images_width, batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    # saver = tf.train.Saver(tf.global_variables())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))
        images_path, labels, lexicons = get_data_lexicon(args)
        predicts = []
        # for img_path, label in zip(images_path, labels):
        for i in tqdm.tqdm(range(len(images_path))):
            img_path = images_path[i]
            label = labels[i]
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None on failure instead of raising
                print("{} error: failed to read image".format(img_path))
                continue

            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run([pred, attention_weights], feed_dict={input_images: [img],
                                                                                                 input_labels: [la],
                                                                                                 input_images_width: [width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            # print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            if args.vis_dir:
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                _ = line_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, img_path)
                if pred_value_str.lower() != label.lower():
                    _ = line_visualize(img, attention_weights_value, pred_value_str, os.path.join(args.vis_dir, "errors"), img_path)
    acc_rate = calc_metrics_lexicon(predicts, labels, lexicons)
    print("Done, Accuracy: {}".format(acc_rate))
Example #3
def idx2label(inputs, id2char=None, char2id=None):

    if id2char is None:
        voc, char2id, id2char = get_vocabulary(voc_type="ALLCASES_SYMBOLS")

    def end_cut(ins):
        cut_ins = []
        for id in ins:
            if id != char2id['EOS']:
                if id != char2id['UNK']:
                    cut_ins.append(id2char[id])
            else:
                break
        return cut_ins

    if isinstance(inputs, np.ndarray):
        assert len(inputs.shape) == 2, "input's rank should be 2"
        results = [''.join(end_cut(ins)) for ins in inputs]
        return results
    else:
        print("input to idx2label should be numpy array")
        return inputs
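For illustration, a toy run of `idx2label` with a hand-built vocabulary (the ids below are made up and do not match the real `get_vocabulary` output):

id2char = {0: 'a', 1: 'b', 2: 'c', 3: 'EOS', 4: 'UNK'}
char2id = {v: k for k, v in id2char.items()}
preds = np.array([[0, 1, 4, 2, 3, 0],   # UNK ids are skipped, EOS stops the cut
                  [2, 2, 3, 1, 1, 1]])
print(idx2label(preds, id2char, char2id))  # ['abc', 'cc']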
Example #4
def generator(image_dir,
              gt_path,
              input_height,
              input_width,
              batch_size,
              max_len,
              voc_type,
              keep_ratio=True,
              with_aug=True):
    if gt_path.split(".")[-1] == 'txt':
        data_loader = TextLoader(image_dir)
    elif gt_path.split(".")[-1] == 'json':
        data_loader = JSONLoader(image_dir)
    else:
        raise ValueError("Unsupported gt file format: {}".format(gt_path))
    images_path, transcriptions = data_loader.parse_gt(gt_path)
    print("There are {} images in {}".format(len(images_path), image_dir))

    assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']

    index = np.arange(0, len(images_path))
    voc, char2id, id2char = get_vocabulary(voc_type)
    # char2id = dict(zip(voc, range(len(voc))))
    # id2char = dict(zip(range(len(voc)), voc))
    is_lowercase = (voc_type == 'LOWERCASE')

    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []

    while True:
        np.random.shuffle(index)
        for i in index:
            try:
                img = cv2.imread(images_path[i])
                word = transcriptions[i]
                if is_lowercase:
                    word = word.lower()

                if img is None:
                    print("corrupt {}".format(images_path[i]))
                    continue
                H, W, C = img.shape

                # Rotate the vertical images
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H

                # Resize the images
                img_resize = np.zeros((input_height, input_width, C),
                                      dtype=np.uint8)
                # new_height = input_height

                # Data augmentation
                if with_aug:
                    ratn = random.randint(0, 4)
                    if ratn == 0:
                        rand_reg = random.random() * 30 - 15
                        img = rotate_img(img, rand_reg)

                if keep_ratio:
                    # Scale the width by the same factor as the height so the
                    # aspect ratio is preserved.
                    new_width = int((1.0 * W / H) * input_height)
                    new_width = min(max(new_width, input_height), input_width)
                    new_height = input_height
                    img = cv2.resize(img, (new_width, new_height))
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = input_width
                    img_resize = cv2.resize(img, (input_width, input_height))

                # Process the labels
                label = np.full((max_len), char2id['PAD'], dtype=np.int32)
                label_mask = np.full((max_len), 0, dtype=np.int32)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])
                # label_list = label_list + [char2id['EOS']]
                # assert len(label_list) <= max_len
                if len(label_list) > (max_len - 1):
                    label_list = label_list[:(max_len - 1)]
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)

                if label.shape[0] <= 0:
                    continue

                label_len = len(label_list)
                label_mask[:label_len] = 1
                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)

                assert len(batch_images) == len(batch_labels) == len(
                    batch_lengths)

                if len(batch_images) == batch_size:
                    yield np.array(batch_images), np.array(
                        batch_labels), np.array(batch_masks), np.array(
                            batch_lengths), batch_labels_str, np.array(
                                batch_images_width)
                    batch_images = []
                    batch_images_width = []
                    batch_labels = []
                    batch_masks = []
                    batch_lengths = []
                    batch_labels_str = []
            except Exception as e:
                print(e)
                print(images_path[i])
                continue
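Because this is an infinite generator, a caller simply pulls batches with `next()`; the six-tuple order matches the `yield` above (the paths and sizes below are placeholders):

gen = generator("images/", "gt.txt", input_height=48, input_width=160,
                batch_size=32, max_len=30, voc_type='ALLCASES_SYMBOLS')
images, labels, masks, lengths, labels_str, widths = next(gen)
print(images.shape)  # (32, 48, 160, 3)
print(labels.shape)  # (32, 30)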
Example #5
def main_test_lmdb(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    input_images = tf.placeholder(dtype=tf.float32, shape=[1, args.height, args.width, 3], name="input_images")
    input_images_width = tf.placeholder(dtype=tf.float32, shape=[1], name="input_images_width")
    input_labels = tf.placeholder(dtype=tf.int32, shape=[1, args.max_len], name="input_labels")
    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)

    # encoder_state, feature_map, mask_map = sar_model.inference(input_images, input_images_width, 1, reuse=True)
    # model_infer, attention_weights, pred = sar_model.decode(encoder_state, feature_map, input_labels, mask_map, reuse=True, decode_type=args.decode_type)
    model_infer, attention_weights, pred, _ = sar_model(input_images, input_labels, input_images_width, batch_size=1, reuse=False)
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int32)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    # saver = tf.train.Saver(tf.global_variables())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
        model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
        print('Restore from {}'.format(model_path))
        saver.restore(sess, model_path)
        print("Checkpoints step: {}".format(global_step.eval(session=sess)))

        env = lmdb.open(args.test_data_dir, readonly=True)
        txn = env.begin()
        num_samples = int(txn.get(b"num-samples").decode())
        predicts = []
        labels = []
        # for img_path, label in zip(images_path, labels):
        for i in tqdm.tqdm(range(1, num_samples+1)):
            image_key = b'image-%09d' % i
            label_key = b'label-%09d' % i

            imgbuf = txn.get(image_key)
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)

            img_pil = Image.open(buf).convert('RGB')
            img = np.array(img_pil)
            label = txn.get(label_key).decode()
            labels.append(label)
            img, la, width = data_preprocess(img, label, char2id, args)

            pred_value, attention_weights_value = sess.run([pred, attention_weights], feed_dict={input_images: [img],
                                                                                                 input_labels: [la],
                                                                                                 input_images_width: [
                                                                                                     width]})
            pred_value_str = idx2label(pred_value, id2char, char2id)[0]
            # print("predict: {} label: {}".format(pred_value_str, label))
            predicts.append(pred_value_str)
            img_vis = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            if args.vis_dir:
                os.makedirs(args.vis_dir, exist_ok=True)
                os.makedirs(os.path.join(args.vis_dir, "errors"), exist_ok=True)
                # _ = line_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, "{}.jpg".format(i))
                _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str, args.vis_dir, "{}.jpg".format(i))
                # _ = mask_visualize(img, attention_weights_value, pred_value_str, args.vis_dir, img_path)
                if pred_value_str.lower() != label.lower():
                    _ = heatmap_visualize(img_vis, attention_weights_value, pred_value_str, os.path.join(args.vis_dir, "errors"), "{}.jpg".format(i))
        acc_rate = calc_metrics(predicts, labels)
        print("Done, Accuracy: {}".format(acc_rate))
Example #6
def main_train(args):

    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    # Build graph
    input_train_images = tf.placeholder(dtype=tf.float32, shape=[None, args.height, args.width, 3], name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32, shape=[None], name="input_train_width")
    input_train_labels = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_train_labels")
    input_train_labels_mask = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_train_labels_mask")

    input_val_images = tf.placeholder(dtype=tf.float32, shape=[None, args.height, args.width, 3],name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32, shape=[None], name="input_val_width")
    input_val_labels = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_val_labels")
    input_val_labels_mask = tf.placeholder(dtype=tf.int32, shape=[None, args.max_len], name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                        encoder_dim=args.encoder_sdim,
                        encoder_layer=args.encoder_layers,
                        decoder_dim=args.decoder_sdim,
                        decoder_layer=args.decoder_layers,
                        decoder_embed_dim=args.decoder_edim,
                        seq_len=args.max_len,
                        is_training=True)
    sar_model_val = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=False)
    train_model_infer, train_attention_weights, train_pred = sar_model(input_train_images, input_train_labels,
                                                                       input_train_images_width,
                                                                       reuse=False)
    train_loss = sar_model.loss(train_model_infer, input_train_labels, input_train_labels_mask)

    val_model_infer, val_attention_weights, val_pred = sar_model_val(input_val_images, input_val_labels,
                                                                       input_val_images_width,
                                                                       reuse=True)
    val_loss = sar_model_val.loss(val_model_infer, input_val_labels, input_val_labels_mask)

    global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=args.lr,
                                               global_step=global_step,
                                               decay_steps=args.decay_iter,
                                               decay_rate=args.weight_decay,
                                               staircase=True)

    batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)



    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_to_restore = variable_averages.variables_to_restore()


    #saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    saver = tf.train.Saver(variables_to_restore)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.checkpoints != '':
            print('Restore model from {:s}'.format(args.checkpoints))
            ckpt_state = tf.train.get_checkpoint_state(args.checkpoints)
            model_path = os.path.join(args.checkpoints, os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())

            # We use a built-in TF helper to export variables to constants
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess,  # The session is used to retrieve the weights
                tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
                ['sar_1/ArgMax']  # The output node names are used to select the useful nodes
            )

            frozen_model_path = os.path.join(args.checkpoints,os.path.basename(ckpt_state.model_checkpoint_path))+".pb"

            with tf.gfile.GFile(frozen_model_path, "wb") as f:
                f.write(output_graph_def.SerializeToString())
            print("Frozen model saved at " + frozen_model_path)
Example #7
def main_train(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)
    tf.set_random_seed(1)
    # Build graph
    input_train_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.height, args.width, 3],
        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels")
    input_train_gauss_labels = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 6, 40],
        name="input_train_gauss_labels")  # better way wanted!!!
    input_train_gauss_tags = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_gauss_tags")
    input_train_gauss_params = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 4],
        name="input_train_gauss_params")
    input_train_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels_mask")
    input_train_char_sizes = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 2],
        name="input_train_char_sizes")
    input_train_gauss_mask = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.max_len, 6, 40],
        name="input_train_gauss_mask")

    input_val_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.val_batch_size, args.height, args.width, 3],
        name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels")
    # input_val_gauss_labels = tf.placeholder(dtype=tf.float32, shape=[args.val_batch_size, args.max_len, args.height, args.width], name="input_val_gauss_labels")
    input_val_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True,
                         att_loss_type=args.att_loss_type,
                         att_loss_weight=args.att_loss_weight)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred, train_attention_params = sar_model(
        input_train_images,
        input_train_labels,
        input_train_images_width,
        batch_size=args.train_batch_size,
        reuse=False)
    if args.att_loss_type == "kldiv":
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask,
            input_train_gauss_tags)
    elif args.att_loss_type == "l1" or args.att_loss_type == "l2":
        # train_loss, train_recog_loss, train_att_loss = sar_model.loss(train_model_infer, train_attention_params, input_train_labels, input_train_gauss_params, input_train_labels_mask, input_train_gauss_tags, input_train_char_sizes)
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask,
            input_train_gauss_tags)
    elif args.att_loss_type == 'ce':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_weights, input_train_labels,
            input_train_gauss_labels, input_train_labels_mask,
            input_train_gauss_tags, input_train_gauss_mask)
    elif args.att_loss_type == 'gausskldiv':
        train_loss, train_recog_loss, train_att_loss = sar_model.loss(
            train_model_infer, train_attention_params, input_train_labels,
            input_train_gauss_params, input_train_labels_mask,
            input_train_gauss_tags)

    else:
        print("Unimplemented loss type {}".format(args.att_loss_dtype))
        exit(-1)
    val_model_infer, val_attention_weights, val_pred, _ = sar_model_val(
        input_val_images,
        input_val_labels,
        input_val_images_width,
        batch_size=args.val_batch_size,
        reuse=True)

    train_data_list = get_data(args.train_data_dir,
                               args.voc_type,
                               args.max_len,
                               args.height,
                               args.width,
                               args.train_batch_size,
                               args.workers,
                               args.keep_ratio,
                               with_aug=args.aug)

    val_data_gen = evaluator_data.Evaluator(lmdb_data_dir=args.test_data_dir,
                                            batch_size=args.val_batch_size,
                                            height=args.height,
                                            width=args.width,
                                            max_len=args.max_len,
                                            keep_ratio=args.keep_ratio,
                                            voc_type=args.voc_type)
    val_data_gen.reset()

    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)

    learning_rate = tf.train.piecewise_constant(global_step, args.decay_bound,
                                                args.lr_stage)
    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='train_recog_loss', tensor=train_recog_loss)
    tf.summary.scalar(name='train_att_loss', tensor=train_att_loss)
    # tf.summary.scalar(name='val_att_loss', tensor=val_att_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    merge_summary_op = tf.summary.merge_all()

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    best_model_save_path = os.path.join(args.checkpoints, 'best_model',
                                        model_name)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies(
        [variables_averages_op, batch_norm_updates_op]):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        grads = optimizer.compute_gradients(train_loss)
        if args.grad_clip > 0:
            print("With Gradients clipped!")
            for idx, (grad, var) in enumerate(grads):
                grads[idx] = (tf.clip_by_norm(grad, args.grad_clip), var)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_file = open(os.path.join(args.checkpoints, args.checkpoints + ".log"),
                    "w")
    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        elif not args.resume and args.pretrained != '':
            print('Restore pretrained model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            sess.run(tf.assign(global_step, 0))
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)

        # Evaluate the model first
        val_pred_value_all = []
        val_labels = []
        for eval_iter in range(val_data_gen.num_samples //
                               args.val_batch_size):

            val_data = val_data_gen.get_batch()
            if val_data is None:
                break
            print("Evaluation: [{} / {}]".format(
                eval_iter, (val_data_gen.num_samples // args.val_batch_size)))
            val_pred_value = sess.run(val_pred,
                                      feed_dict={
                                          input_val_images: val_data[0],
                                          input_val_labels: val_data[1],
                                          input_val_images_width: val_data[5],
                                          input_val_labels_mask: val_data[2]
                                      })
            val_pred_value_all.extend(val_pred_value)
            val_labels.extend(val_data[4])

        val_data_gen.reset()
        val_metrics_result = calc_metrics(idx2label(
            np.array(val_pred_value_all)),
                                          val_labels,
                                          metrics_type="accuracy")
        print("Evaluation Before training: Test accuracy {:3f}".format(
            val_metrics_result))
        val_best_acc = val_metrics_result

        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_recog_loss_value, train_att_loss_value, train_pred_value = sess.run(
                [
                    train_op, train_loss, train_recog_loss, train_att_loss,
                    train_pred
                ],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_gauss_labels: train_data[2],
                    input_train_gauss_params: train_data[7],
                    input_train_labels_mask: train_data[3],
                    input_train_images_width: train_data[5],
                    input_train_gauss_tags: train_data[6],
                    input_train_char_sizes: train_data[8],
                    input_train_gauss_mask: train_data[9]
                })

            if start_iter % args.log_iter == 0:
                print(
                    "Iter {} train loss= {:3f} (recog loss= {:3f} att loss= {:3f})"
                    .format(start_iter, train_loss_value,
                            train_recog_loss_value, train_att_loss_value))
                log_file.write(
                    "Iter {} train loss= {:3f} (recog loss= {:3f} att loss= {:3f})\n"
                    .format(start_iter, train_loss_value,
                            train_recog_loss_value, train_att_loss_value))
            if start_iter % args.summary_iter == 0:
                merge_summary_value = sess.run(merge_summary_op,
                                               feed_dict={
                                                   input_train_images:
                                                   train_data[0],
                                                   input_train_labels:
                                                   train_data[1],
                                                   input_train_gauss_labels:
                                                   train_data[2],
                                                   input_train_gauss_params:
                                                   train_data[7],
                                                   input_train_labels_mask:
                                                   train_data[3],
                                                   input_train_images_width:
                                                   train_data[5],
                                                   input_train_gauss_tags:
                                                   train_data[6],
                                                   input_train_char_sizes:
                                                   train_data[8],
                                                   input_train_gauss_mask:
                                                   train_data[9]
                                               })

                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=start_iter)
                if start_iter % args.eval_iter == 0:
                    val_pred_value_all = []
                    val_labels = []
                    for eval_iter in range(val_data_gen.num_samples //
                                           args.val_batch_size):
                        val_data = val_data_gen.get_batch()
                        if val_data is None:
                            break
                        print("Evaluation: [{} / {}]".format(
                            eval_iter,
                            (val_data_gen.num_samples // args.val_batch_size)))
                        val_pred_value = sess.run(val_pred,
                                                  feed_dict={
                                                      input_val_images:
                                                      val_data[0],
                                                      input_val_labels:
                                                      val_data[1],
                                                      input_val_labels_mask:
                                                      val_data[2],
                                                      input_val_images_width:
                                                      val_data[5]
                                                  })
                        val_pred_value_all.extend(val_pred_value)
                        val_labels.extend(val_data[4])

                    val_data_gen.reset()
                    train_metrics_result = calc_metrics(
                        idx2label(train_pred_value),
                        train_data[4],
                        metrics_type="accuracy")
                    val_metrics_result = calc_metrics(idx2label(
                        np.array(val_pred_value_all)),
                                                      val_labels,
                                                      metrics_type="accuracy")
                    print(
                        "Evaluation Iter {} train accuracy: {:3f} test accuracy {:3f}"
                        .format(start_iter, train_metrics_result,
                                val_metrics_result))
                    log_file.write(
                        "Evaluation Iter {} train accuracy: {:3f} test accuracy {:3f}\n"
                        .format(start_iter, train_metrics_result,
                                val_metrics_result))

                    if val_metrics_result >= val_best_acc:
                        print("Better results! Save checkpoitns to {}".format(
                            best_model_save_path))
                        val_best_acc = val_metrics_result
                        best_saver.save(sess,
                                        best_model_save_path,
                                        global_step=global_step)

            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
    log_file.close()
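The exponential-moving-average pattern above pairs with the test scripts: training keeps shadow copies of the weights, and inference restores those shadows. Reduced to its core (assuming `global_step` from the surrounding graph):

ema = tf.train.ExponentialMovingAverage(0.997, global_step)
ema_update_op = ema.apply(tf.trainable_variables())  # run once per step (via control_dependencies above)

# Test side (Examples #1, #2, #5): restore the shadow values into the live
# variables, so inference runs on the averaged weights.
saver = tf.train.Saver(ema.variables_to_restore())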
Example #8
def generator(lmdb_dir,
              input_height,
              input_width,
              batch_size,
              max_len,
              voc_type,
              keep_ratio=True,
              with_aug=True):
    env = lmdb.open(lmdb_dir, max_readers=32, readonly=True)
    txn = env.begin()

    num_samples = int(txn.get(b"num-samples").decode())
    print("There are {} images in {}".format(num_samples, lmdb_dir))
    index = np.arange(0, num_samples)  # TODO check index is reliable

    voc, char2id, id2char = get_vocabulary(voc_type)
    is_lowercase = (voc_type == 'LOWERCASE')

    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []

    while True:
        np.random.shuffle(index)
        for i in index:
            i += 1
            try:
                image_key = b'image-%09d' % i
                label_key = b'label-%09d' % i

                imgbuf = txn.get(image_key)
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)

                img_pil = Image.open(buf).convert('RGB')
                img = np.array(img_pil)
                word = txn.get(label_key).decode()
                if is_lowercase:
                    word = word.lower()
                H, W, C = img.shape

                # Rotate the vertical images
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H

                # Resize the images
                img_resize = np.zeros((input_height, input_width, C),
                                      dtype=np.uint8)

                # Data augmentation
                if with_aug:
                    ratn = random.randint(0, 4)
                    if ratn == 0:
                        rand_reg = random.random() * 30 - 15
                        img = rotate_img(img, rand_reg)

                if keep_ratio:
                    # Scale the width by the same factor as the height so the
                    # aspect ratio is preserved.
                    new_width = int((1.0 * W / H) * input_height)
                    new_width = min(max(new_width, input_height), input_width)
                    new_height = input_height
                    img = cv2.resize(img, (new_width, new_height))
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = input_width
                    img_resize = cv2.resize(img, (input_width, input_height))

                label = np.full((max_len), char2id['PAD'], dtype=np.int32)
                label_mask = np.full((max_len), 0, dtype=np.int32)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])

                if len(label_list) > (max_len - 1):
                    label_list = label_list[:(max_len - 1)]
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)

                if label.shape[0] <= 0:
                    continue

                label_len = len(label_list)
                label_mask[:label_len] = 1

                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)

                assert len(batch_images) == len(batch_labels) == len(
                    batch_lengths)

                if len(batch_images) == batch_size:
                    yield np.array(batch_images), np.array(
                        batch_labels), np.array(batch_masks), np.array(
                            batch_lengths), batch_labels_str, np.array(
                                batch_images_width)
                    batch_images = []
                    batch_images_width = []
                    batch_labels = []
                    batch_masks = []
                    batch_lengths = []
                    batch_labels_str = []

            except Exception as e:
                print(e)
                print("Error in %d" % i)
                continue
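The generator expects a 1-indexed LMDB with `num-samples`, `image-%09d`, and `label-%09d` keys. A small writer sketch that produces this layout (the function and paths are hypothetical):

def write_lmdb(lmdb_dir, image_paths, labels):
    env = lmdb.open(lmdb_dir, map_size=1 << 40)  # generous map size
    with env.begin(write=True) as txn:
        for i, (path, label) in enumerate(zip(image_paths, labels), start=1):
            with open(path, "rb") as f:
                txn.put(b'image-%09d' % i, f.read())  # raw encoded image bytes
            txn.put(b'label-%09d' % i, label.encode())
        txn.put(b"num-samples", str(len(image_paths)).encode())
    env.close()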
Example #9
    def get_batch(self):
        voc, char2id, id2char = get_vocabulary(self.voc_type)
        is_lowercase = (self.voc_type == 'LOWERCASE')

        batch_images = []
        batch_images_width = []
        batch_labels = []
        batch_lengths = []
        batch_masks = []
        batch_labels_str = []
        if (self.index + self.batch_size - 1) > self.num_samples:
            self.reset()
            return None

        # for i in range(self.index, self.index + self.batch_size - 1):

        while len(batch_images) < self.batch_size:
            try:
                image_key = b'image-%09d' % self.index
                label_key = b'label-%09d' % self.index
                self.index += 1

                imgbuf = self.txn.get(image_key)
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)

                img_pil = Image.open(buf).convert('RGB')
                img = np.array(img_pil)
                word = self.txn.get(label_key).decode()
                if is_lowercase:
                    word = word.lower()
                H, W, C = img.shape

                # Rotate the vertical images
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H

                # Resize the images
                img_resize = np.zeros((self.height, self.width, C),
                                      dtype=np.uint8)

                if self.keep_ratio:
                    # Scale the width by the same factor as the height so the
                    # aspect ratio is preserved.
                    new_width = int((1.0 * W / H) * self.height)
                    new_width = min(max(new_width, self.height), self.width)
                    new_height = self.height
                    img = cv2.resize(img, (new_width, new_height))
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = self.width
                    img_resize = cv2.resize(img, (self.width, self.height))

                label = np.full((self.max_len), char2id['PAD'], dtype=np.int32)
                label_mask = np.full((self.max_len), 0, dtype=np.int32)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])

                if len(label_list) > (self.max_len - 1):
                    label_list = label_list[:(self.max_len - 1)]
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)

                if label.shape[0] <= 0:
                    continue

                label_len = len(label_list)
                label_mask[:label_len] = 1

                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)

                assert len(batch_images) == len(batch_labels) == len(
                    batch_lengths)

            except Exception as e:
                print(e)
                print("Error in %d" % self.index)
                continue

        return np.array(batch_images), np.array(batch_labels), np.array(
            batch_masks), np.array(batch_lengths), batch_labels_str, np.array(
                batch_images_width)
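Example #7 shows the intended consumption pattern: call `reset()`, then loop `get_batch()` until it returns `None` (which also resets the cursor). In isolation (constructor arguments as in Example #7; the LMDB path is a placeholder):

evaluator = evaluator_data.Evaluator(lmdb_data_dir="test_lmdb", batch_size=32,
                                     height=48, width=160, max_len=30,
                                     keep_ratio=True, voc_type='LOWERCASE')
evaluator.reset()
while True:
    batch = evaluator.get_batch()
    if batch is None:  # one full pass over num_samples is done
        break
    images, labels, masks, lengths, labels_str, widths = batch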
Example #10
def main_train(args):
    voc, char2id, id2char = get_vocabulary(voc_type=args.voc_type)

    # Build graph
    input_train_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.train_batch_size, args.height, args.width, 3],
        name="input_train_images")
    input_train_images_width = tf.placeholder(dtype=tf.float32,
                                              shape=[args.train_batch_size],
                                              name="input_train_width")
    input_train_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels")
    input_train_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.train_batch_size, args.max_len],
        name="input_train_labels_mask")

    input_val_images = tf.placeholder(
        dtype=tf.float32,
        shape=[args.val_batch_size, args.height, args.width, 3],
        name="input_val_images")
    input_val_images_width = tf.placeholder(dtype=tf.float32,
                                            shape=[args.val_batch_size],
                                            name="input_val_width")
    input_val_labels = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels")
    input_val_labels_mask = tf.placeholder(
        dtype=tf.int32,
        shape=[args.val_batch_size, args.max_len],
        name="input_val_labels_mask")

    sar_model = SARModel(num_classes=len(voc),
                         encoder_dim=args.encoder_sdim,
                         encoder_layer=args.encoder_layers,
                         decoder_dim=args.decoder_sdim,
                         decoder_layer=args.decoder_layers,
                         decoder_embed_dim=args.decoder_edim,
                         seq_len=args.max_len,
                         is_training=True)
    sar_model_val = SARModel(num_classes=len(voc),
                             encoder_dim=args.encoder_sdim,
                             encoder_layer=args.encoder_layers,
                             decoder_dim=args.decoder_sdim,
                             decoder_layer=args.decoder_layers,
                             decoder_embed_dim=args.decoder_edim,
                             seq_len=args.max_len,
                             is_training=False)
    train_model_infer, train_attention_weights, train_pred = sar_model(
        input_train_images,
        input_train_labels,
        input_train_images_width,
        batch_size=args.train_batch_size,
        reuse=False)
    train_loss = sar_model.loss(train_model_infer, input_train_labels,
                                input_train_labels_mask)

    val_model_infer, val_attention_weights, val_pred = sar_model_val(
        input_val_images,
        input_val_labels,
        input_val_images_width,
        batch_size=args.val_batch_size,
        reuse=True)
    val_loss = sar_model_val.loss(val_model_infer, input_val_labels,
                                  input_val_labels_mask)

    train_data_list = get_data(args.train_data_dir,
                               args.train_data_gt,
                               args.voc_type,
                               args.max_len,
                               args.num_train,
                               args.height,
                               args.width,
                               args.train_batch_size,
                               args.workers,
                               args.keep_ratio,
                               with_aug=args.aug)

    val_data_list = get_data(args.test_data_dir,
                             args.test_data_gt,
                             args.voc_type,
                             args.max_len,
                             args.num_train,
                             args.height,
                             args.width,
                             args.val_batch_size,
                             args.workers,
                             args.keep_ratio,
                             with_aug=False)

    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=args.lr,
                                               global_step=global_step,
                                               decay_steps=args.decay_iter,
                                               decay_rate=args.weight_decay,
                                               staircase=True)

    batch_norm_updates_op = tf.group(
        *tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads = optimizer.compute_gradients(train_loss)
    apply_gradient_op = optimizer.apply_gradients(grads,
                                                  global_step=global_step)

    # Save summary
    os.makedirs(args.checkpoints, exist_ok=True)
    tf.summary.scalar(name='train_loss', tensor=train_loss)
    tf.summary.scalar(name='val_loss', tensor=val_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    merge_summary_op = tf.summary.merge_all()

    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = 'sar_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(args.checkpoints, model_name)
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    summary_writer = tf.summary.FileWriter(args.checkpoints)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        summary_writer.add_graph(sess.graph)
        start_iter = 0
        if args.resume and args.pretrained != '':
            print('Restore model from {:s}'.format(args.pretrained))
            ckpt_state = tf.train.get_checkpoint_state(args.pretrained)
            model_path = os.path.join(
                args.pretrained,
                os.path.basename(ckpt_state.model_checkpoint_path))
            saver.restore(sess=sess, save_path=model_path)
            start_iter = sess.run(tf.train.get_global_step())
        else:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)

        while start_iter < args.iters:
            start_iter += 1
            train_data = get_batch_data(train_data_list, args.train_batch_size)
            _, train_loss_value, train_pred_value = sess.run(
                [train_op, train_loss, train_pred],
                feed_dict={
                    input_train_images: train_data[0],
                    input_train_labels: train_data[1],
                    input_train_labels_mask: train_data[2],
                    input_train_images_width: train_data[4]
                })

            if start_iter % args.log_iter == 0:
                print("Iter {} train loss= {:3f}".format(
                    start_iter, train_loss_value))

            if start_iter % args.summary_iter == 0:
                val_data = get_batch_data(val_data_list, args.val_batch_size)

                merge_summary_value, val_pred_value, val_loss_value = sess.run(
                    [merge_summary_op, val_pred, val_loss],
                    feed_dict={
                        input_train_images: train_data[0],
                        input_train_labels: train_data[1],
                        input_train_labels_mask: train_data[2],
                        input_train_images_width: train_data[4],
                        input_val_images: val_data[0],
                        input_val_labels: val_data[1],
                        input_val_labels_mask: val_data[2],
                        input_val_images_width: val_data[4]
                    })

                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=start_iter)
                if start_iter % args.eval_iter == 0:
                    print("#" * 80)
                    print("train prediction \t train labels ")
                    for result, gt in zip(idx2label(train_pred_value),
                                          train_data[3]):
                        print("{} \t {}".format(result, gt))
                    print("#" * 80)
                    print("test prediction \t test labels ")
                    for result, gt in zip(
                            idx2label(val_pred_value)[:32], val_data[3][:32]):
                        print("{} \t {}".format(result, gt))
                    print("#" * 80)

                    train_metrics_result = calc_metrics(
                        idx2label(train_pred_value),
                        train_data[3],
                        metrics_type="accuracy")
                    val_metrics_result = calc_metrics(
                        idx2label(val_pred_value),
                        val_data[3],
                        metrics_type="accuracy")
                    print(
                        "Evaluation Iter {} test loss: {:.3f} train accuracy: {:.3f} test accuracy: {:.3f}"
                        .format(start_iter, val_loss_value,
                                train_metrics_result, val_metrics_result))
            if start_iter % args.save_iter == 0:
                print("Iter {} save to checkpoint".format(start_iter))
                saver.save(sess, model_save_path, global_step=global_step)
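
The loop above relies on CLI flags parsed elsewhere. A minimal argparse sketch of the assumed interface (flag names follow the usages above; the defaults are illustrative, not taken from the original code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--iters', type=int, default=300000)        # total training iterations
parser.add_argument('--log_iter', type=int, default=100)        # print train loss every N iters
parser.add_argument('--summary_iter', type=int, default=1000)   # write TensorBoard summaries
parser.add_argument('--eval_iter', type=int, default=5000)      # print predictions and accuracy
parser.add_argument('--save_iter', type=int, default=10000)     # checkpoint frequency
parser.add_argument('--train_batch_size', type=int, default=32)
parser.add_argument('--val_batch_size', type=int, default=32)
parser.add_argument('--checkpoints', type=str, default='./checkpoints')
parser.add_argument('--resume', action='store_true')            # restore from --pretrained
parser.add_argument('--pretrained', type=str, default='')
args = parser.parse_args()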
Example #11
0
def generator(lmdb_dir,
              input_height,
              input_width,
              batch_size,
              max_len,
              voc_type,
              keep_ratio=True,
              with_aug=True):
    env = lmdb.open(lmdb_dir, max_readers=32, readonly=True)
    txn = env.begin()

    if txn.get(b"num-samples") is None:  # SynthText800K is still generating
        num_samples = 3000000
    else:
        num_samples = int(txn.get(b"num-samples").decode())
    print("There are {} images in {}".format(num_samples, lmdb_dir))
    index = np.arange(0, num_samples)  # TODO check index is reliable

    voc, char2id, id2char = get_vocabulary(voc_type)
    is_lowercase = (voc_type == 'LOWERCASE')

    batch_images = []
    batch_images_width = []
    batch_labels = []
    batch_gausses = []
    batch_params = []
    batch_gauss_tags = []
    batch_gauss_masks = []
    batch_lengths = []
    batch_masks = []
    batch_labels_str = []
    batch_char_size = []
    batch_char_bbs = []

    while True:
        np.random.shuffle(index)
        for i in index:
            i += 1  # lmdb keys are 1-based, np.arange is 0-based
            try:
                image_key = b'image-%09d' % i
                label_key = b'label-%09d' % i
                char_key = b'char-%09d' % i

                imgbuf = txn.get(image_key)
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)

                img_pil = Image.open(buf).convert('RGB')
                img = np.array(img_pil)
                word = txn.get(label_key).decode()
                wo_char = (txn.get(char_key) is None)
                if wo_char:
                    charBBs = np.zeros(dtype=np.float32,
                                       shape=[len(word), 4, 2])
                else:
                    charBBs = np.array([
                        math.ceil(float(c))
                        for c in txn.get(char_key).decode().split()
                    ]).reshape([-1, 4, 2]).astype(np.float32)

                num_char_bb = charBBs.shape[0]
                if is_lowercase:
                    word = word.lower()
                H, W, C = img.shape

                charBBs[:, :, 0] = np.clip(charBBs[:, :, 0], 0, W)
                charBBs[:, :, 1] = np.clip(charBBs[:, :, 1], 0, H)

                # Rotate the vertical images
                if H > 1.1 * W:
                    img = np.rot90(img)
                    H, W = W, H
                    charBBs = charBBs[:, :, ::-1]
                    charBBs[:, :, 1] = H - charBBs[:, :, 1]
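                    # np.rot90 rotates 90 deg counter-clockwise, taking a point
                    # (x, y) to (y, W_old - x): hence the coordinate swap and flip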

                # Resize the images
                img_resize = np.zeros((input_height, input_width, C),
                                      dtype=np.uint8)

                # Data augmentation
                if with_aug:
                    blur_aug = imgaug.augmenters.blur.GaussianBlur(sigma=(0,
                                                                          1.0))
                    contrast_aug = imgaug.augmenters.contrast.LinearContrast(
                        (0.75, 1.0))
                    affine_aug = imgaug.augmenters.geometric.PiecewiseAffine(
                        scale=(0.01, 0.02), mode='constant', cval=0)
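                    # NOTE: building these augmenters per image is wasteful; they
                    # could be constructed once outside the loop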

                    # Gaussian Blur
                    ratn = random.randint(0, 1)
                    # img = cv2.GaussianBlur(img, (3, 3), 2)
                    if ratn == 0:
                        img = blur_aug.augment_image(img)
                    # Gaussian noise
                    # img = adding_guass(img)

                    # Contrast
                    ratn = random.randint(0, 1)
                    if ratn == 0:
                        img = contrast_aug.augment_image(img)

                    # Affine
                    ratn = random.randint(0, 1)
                    if ratn == 0:
                        img = affine_aug.augment_image(img)

                    # Rotation
                    ratn = random.randint(0, 1)
                    if ratn == 0:
                        # rand_reg = random.random() * 30 - 15
                        rand_reg = random.randint(-15, 15)
                        img, charBBs, (W, H) = rotate_img(img,
                                                          rand_reg,
                                                          BBs=charBBs)

                if keep_ratio:
                    # Keep the aspect ratio: scale width by input_height / H,
                    # then clamp to [input_height, input_width]
                    new_width = int((1.0 * input_height / H) * W)
                    new_width = min(max(new_width, input_height), input_width)
                    new_height = input_height
                    img = cv2.resize(img, (new_width, new_height))
                    # Right-pad the fixed-size canvas with zeros
                    img_resize[:new_height, :new_width, :] = img.copy()
                else:
                    new_width = input_width
                    new_height = input_height
                    img_resize = cv2.resize(img, (input_width, input_height))

                ratio_w = float(new_width) / float(W)
                ratio_h = float(new_height) / float(H)

                charBBs[:, :, 0] = charBBs[:, :, 0] * ratio_w
                charBBs[:, :, 1] = charBBs[:, :, 1] * ratio_h

                # visualization for debugging rotate augmentation
                # img_debug = img_resize.copy()
                # for bb in charBBs:
                #     img_debug = cv2.polylines(img_debug, [bb.astype(np.int32).reshape((-1, 1, 2))], True,
                #                               color=(255, 255, 0), thickness=1)
                # cv2.imwrite("./char_bb_vis/{}.jpg".format(i), img_debug)

                # feature_map_w = 40.
                # feature_map_h = 6.
                # feature_map_w = input_width
                # feature_map_h = input_height
                # ratio_w_f = feature_map_w / input_width
                # ratio_h_f = feature_map_h / input_height
                # charBBs[:, :, 0] = charBBs[:, :, 0] * ratio_w_f
                # charBBs[:, :, 1] = charBBs[:, :, 1] * ratio_h_f

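                # Encode the word as fixed-length ids, e.g. with max_len = 6:
                # "cat" -> [c, a, t, EOS, PAD, PAD], mask = [1, 1, 1, 1, 0, 0]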
                label = np.full((max_len,), char2id['PAD'], dtype=np.int32)
                label_mask = np.full((max_len,), 0, dtype=np.int32)
                label_list = []
                for char in word:
                    if char in char2id:
                        label_list.append(char2id[char])
                    else:
                        label_list.append(char2id['UNK'])

                if len(label_list) > (max_len - 1):
                    label_list = label_list[:(max_len - 1)]
                    num_char_bb = max_len - 1
                label_list = label_list + [char2id['EOS']]
                label[:len(label_list)] = np.array(label_list)

                if len(word) == 0:  # skip empty labels
                    continue

                label_len = len(label_list)
                label_mask[:label_len] = 1
                if (label_len - 1) != num_char_bb:
                    print(
                        "Mismatch between char bb count and label length at index {}"
                        .format(i))
                    print(
                        "Information: label: {} label_length: {} num_char: {}".
                        format(word, label_len - 1, num_char_bb))
                    continue

                # Get gaussian distribution labels
                # gauss_labels = np.zeros(dtype=np.float32, shape=[max_len, int(feature_map_h), int(feature_map_w)]) # T * H * W
                gauss_labels = np.zeros(
                    dtype=np.float32,
                    shape=[max_len, input_height, input_width])  # T * H * W
                # gauss_mask = np.zeros(dtype=np.float32, shape=[max_len, int(feature_map_h), int(feature_map_w)]) # T * H * W
                gauss_mask = np.zeros(
                    dtype=np.float32,
                    shape=[max_len, input_height, input_width])  # T * H * W
                # distrib_params = []
                distrib_params = np.zeros(dtype=np.float32, shape=[max_len, 4])
                distrib_params[:, 2:] = 1.
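                # One 4-vector of Gaussian params per character slot; the last two
                # entries (presumably the variances) default to 1 for empty slots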
                char_size = np.ones(dtype=np.float32, shape=[max_len, 2])
                gauss_tags = [0] * max_len
                charBBs = charBBs[:num_char_bb]
                if not wo_char:
                    # use j so the outer lmdb index i is not shadowed
                    for j, BB in enumerate(charBBs):  # BB: 4 corner points, 4 * 2
                        # try:
                        # Here we use min bounding rectangles
                        min_rec, delta_x, delta_y = find_min_rectangle(
                            BB)  # 4 * 2
                        if delta_x < 2 or delta_y < 2:
                            param = get_distrib_params(BB)
                            distrib_params[j] = param
                            continue
                        gauss_distrib = get_gauss_distrib(
                            (delta_y, delta_x))  # delta_y * delta_x
                        # param = get_distrib_params(BB)
                        param = estim_gauss_params(gauss_distrib, delta_x,
                                                   delta_y)
                        # param[0] = param[0] / feature_map_w
                        # param[1] = param[1] / feature_map_h
                        # param[2] = param[2] / (0.25 * feature_map_w * feature_map_w)
                        # param[3] = param[3] / (0.25 * feature_map_h * feature_map_h)
                        # gauss_distrib = construct_gauss_distirb(param, delta_x, delta_y)
                        char_size[j][0] = delta_x
                        char_size[j][1] = delta_y
                        distrib_params[j] = param
                        # res_gauss = aff_gaussian(gauss_distrib, min_rec, BB, delta_x, delta_y) # delta_y * delta_x
                        res_gauss = gauss_distrib
                        gauss_tags[j] = 1.
                        if np.max(res_gauss) > 0.:
                            start_x, start_y = int(min_rec[0][0]), int(
                                min_rec[0][1])
                            end_x, end_y = start_x + delta_x, start_y + delta_y
                            end_x = min(end_x, input_width)
                            end_y = min(end_y, input_height)
                            # Crop res_gauss in case the box was clipped at the border
                            gauss_labels[j, start_y:end_y,
                                         start_x:end_x] = res_gauss[:end_y - start_y,
                                                                    :end_x - start_x]

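                            # Expand the box by 30% per side to build the mask region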
                            ex_start_x = math.floor(start_x - 0.3 * delta_x)
                            ex_end_x = math.ceil(end_x + 0.3 * delta_x)
                            ex_start_y = math.floor(start_y - 0.3 * delta_y)
                            ex_end_y = math.ceil(end_y + 0.3 * delta_y)

                            ex_start_x = max(ex_start_x, 0)
                            ex_start_y = max(ex_start_y, 0)
                            ex_end_x = min(ex_end_x, input_width)
                            ex_end_y = min(ex_end_y, input_height)

                            gauss_mask[j,
                                       int(ex_start_y):int(ex_end_y),
                                       int(ex_start_x):int(ex_end_x)] = 1.
                            # gauss_mask[i, int(start_y):int(end_y), int(start_x):int(end_x)] = 1.

                        # except Exception as e:
                        #     print(e)
                    gauss_labels = sum_norm(
                        gauss_labels.reshape(
                            [gauss_labels.shape[0],
                             -1])).reshape([-1, input_height, input_width])

                # Reduce to feature map size
                gauss_labels = roi_sum(gauss_labels, target_h=6, target_w=40)
                gauss_mask = roi_max(gauss_mask, target_h=6, target_w=40)
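                # Assumption: 6 x 40 is the encoder feature-map size for a
                # 48 x 160 input (8x vertical, 4x horizontal downsampling)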
                # distrib_params = np.array(distrib_params)

                batch_images.append(img_resize)
                batch_images_width.append(new_width)
                batch_labels.append(label)
                batch_gausses.append(gauss_labels)
                batch_params.append(distrib_params)
                batch_gauss_tags.append(gauss_tags)
                batch_gauss_masks.append(gauss_mask)
                batch_masks.append(label_mask)
                batch_lengths.append(label_len)
                batch_labels_str.append(word)
                batch_char_size.append(char_size)
                batch_char_bbs.append(charBBs)

                assert len(batch_images) == len(batch_labels) == len(
                    batch_lengths) == len(batch_gausses) == len(
                        batch_char_size)

                if len(batch_images) == batch_size:
                    yield np.array(batch_images), \
                          np.array(batch_labels), \
                          np.array(batch_gausses), \
                          np.array(batch_masks), \
                          np.array(batch_lengths), \
                          batch_labels_str, \
                          np.array(batch_images_width), \
                          np.array(batch_gauss_tags), \
                          np.array(batch_params), \
                          np.array(batch_char_size), \
                          np.array(batch_gauss_masks)
                    batch_images = []
                    batch_images_width = []
                    batch_labels = []
                    batch_gausses = []
                    batch_params = []
                    batch_gauss_tags = []
                    batch_gauss_masks = []
                    batch_masks = []
                    batch_lengths = []
                    batch_labels_str = []
                    batch_char_size = []
                    batch_char_bbs = []  # also reset, otherwise it grows without bound

            except Exception as e:
                print(e)
                print("Error in %d" % i)
                continue
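
A minimal sketch of how this generator might be consumed. The LMDB path and shape parameters are illustrative; 48 x 160 is assumed only because it matches the 6 x 40 feature-map targets above:

# Hypothetical usage of the generator above
gen = generator('./data/train_lmdb', input_height=48, input_width=160,
                batch_size=32, max_len=30, voc_type='LOWERCASE')
(images, labels, gausses, masks, lengths, labels_str, widths,
 gauss_tags, params, char_size, gauss_masks) = next(gen)
print(images.shape)   # (32, 48, 160, 3)
print(gausses.shape)  # (32, 30, 6, 40)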