Code Example #1
def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)

    input_images = tf.placeholder(tf.float32,
                                  shape=[None, None, None, 3],
                                  name='input_images')
    # note: this placeholder carries segmentation maps but keeps the tensor
    # name 'input_score_maps'
    input_seg_maps = tf.placeholder(tf.float32,
                                    shape=[None, None, None, 6],
                                    name='input_score_maps')
    input_training_masks = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 1],
                                          name='input_training_masks')

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               decay_steps=10000,
                                               decay_rate=0.94,
                                               staircase=True)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    # opt = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.MomentumOptimizer(learning_rate, 0.9)

    # split the batch across the GPUs; `gpus` is assumed to be defined at
    # module level from FLAGS.gpu_list, e.g.
    # gpus = list(range(len(FLAGS.gpu_list.split(','))))
    input_images_split = tf.split(input_images, len(gpus))
    input_seg_maps_split = tf.split(input_seg_maps, len(gpus))
    input_training_masks_split = tf.split(input_training_masks, len(gpus))

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                isegs = input_seg_maps_split[i]
                itms = input_training_masks_split[i]
                total_loss, model_loss = tower_loss(iis, isegs, itms,
                                                    reuse_variables)
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # note: after the loop only the last tower's update ops remain
                reuse_variables = True

                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(
            FLAGS.pretrained_model_path,
            slim.get_trainable_variables(),
            ignore_missing_vars=True)
    gpu_options = tf.GPUOptions(allow_growth=True)
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.75)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                          allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            logger.info('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            logger.debug(ckpt)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        data_generator = data_provider.get_batch(
            num_workers=FLAGS.num_readers,
            input_size=FLAGS.input_size,
            batch_size=FLAGS.batch_size_per_gpu * len(gpus))

        start = time.time()
        for step in range(FLAGS.max_steps):
            data = next(data_generator)
            ml, tl, _ = sess.run(
                [model_loss, total_loss, train_op],
                feed_dict={
                    input_images: data[0],
                    input_seg_maps: data[2],
                    input_training_masks: data[3]
                })
            if np.isnan(tl):
                logger.error('Loss diverged, stop training')
                break

            if step % 10 == 0:
                avg_time_per_step = (time.time() - start) / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu *
                                           len(gpus)) / (time.time() - start)
                start = time.time()
                logger.info(
                    'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, ml, tl, avg_time_per_step,
                            avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess,
                           os.path.join(FLAGS.checkpoint_path, 'model.ckpt'),
                           global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                # note: this runs one extra optimization step on the same
                # batch in order to fetch the merged summaries
                _, tl, summary_str = sess.run(
                    [train_op, total_loss, summary_op],
                    feed_dict={
                        input_images: data[0],
                        input_seg_maps: data[2],
                        input_training_masks: data[3]
                    })
                summary_writer.add_summary(summary_str, global_step=step)
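
Example #1 references tower_loss and average_gradients without defining them. Below is a minimal sketch of average_gradients, assuming it follows the standard TensorFlow multi-tower pattern rather than this project's exact implementation:

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU tower.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars pairs the gradients of the same variable across towers.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # variables are shared across towers; the first tower's handle suffices.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
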
Code Example #2
            # (tail of a truncated generator() definition; the enclosing
            # try/while block is not part of this excerpt)
            except Exception as e:
                traceback.print_exc()
                continue


def get_batch(num_workers, **kwargs):
    enqueuer = None  # so the finally block is safe if construction fails
    try:
        enqueuer = GeneratorEnqueuer(generator(**kwargs),
                                     use_multiprocessing=True)
        enqueuer.start(max_queue_size=24, workers=num_workers)
        generator_output = None
        while True:
            while enqueuer.is_running():
                if not enqueuer.queue.empty():
                    generator_output = enqueuer.queue.get()
                    break
                else:
                    time.sleep(0.01)
            yield generator_output
            generator_output = None
    finally:
        if enqueuer is not None:
            enqueuer.stop()


if __name__ == '__main__':
    gen = get_batch(num_workers=2, vis=True)
    while True:
        image, bbox, im_info = next(gen)
        logger.debug('done')
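
The generator(**kwargs) wrapped by GeneratorEnqueuer above is only partially shown. The following is a hypothetical sketch of its expected shape, with random tensors standing in for the real augmentation and label-generation pipeline, so that train.py can index data[0], data[2] and data[3]:

import numpy as np

def generator(input_size=512, batch_size=8, vis=False, **kwargs):
    # Hypothetical stand-in: yields [images, image_fns, seg_maps, training_masks].
    # Channel counts match the placeholders in train.py; the spatial shapes
    # here are illustrative only.
    while True:
        images = np.random.rand(batch_size, input_size, input_size,
                                3).astype(np.float32)
        image_fns = ['sample_%d.jpg' % i for i in range(batch_size)]
        seg_maps = np.zeros((batch_size, input_size, input_size, 6), np.float32)
        masks = np.ones((batch_size, input_size, input_size, 1), np.float32)
        yield [images, image_fns, seg_maps, masks]
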
Code Example #3
def main(argv=None):

    import errno
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:  # ignore "directory already exists"
            raise

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        seg_maps_pred = model.model(input_images, is_training=False)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(
                FLAGS.checkpoint_path,
                os.path.basename(ckpt_state.model_checkpoint_path))
            logger.info('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                im = cv2.imread(im_fn)[:, :, ::-1]
                logger.debug('image file:{}'.format(im_fn))

                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                h, w, _ = im_resized.shape
                # options = tf.RunOptions(trace_level = tf.RunOptions.FULL_TRACE)
                # run_metadata = tf.RunMetadata()
                timer = {'net': 0, 'pse': 0}
                start = time.time()
                seg_maps = sess.run(seg_maps_pred,
                                    feed_dict={input_images: [im_resized]})
                timer['net'] = time.time() - start
                # fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                # chrome_trace = fetched_timeline.generate_chrome_trace_format()
                # with open(os.path.join(FLAGS.output_dir, os.path.basename(im_fn).split('.')[0]+'.json'), 'w') as f:
                #     f.write(chrome_trace)

                boxes, kernels, timer = detect(seg_maps=seg_maps,
                                               timer=timer,
                                               image_w=w,
                                               image_h=h)
                logger.info('{} : net {:.0f}ms, pse {:.0f}ms'.format(
                    im_fn, timer['net'] * 1000, timer['pse'] * 1000))

                if boxes is not None:
                    boxes = boxes.reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h
                    h, w, _ = im.shape
                    boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, w)
                    boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, h)

                duration = time.time() - start_time
                logger.info('[timing] {}'.format(duration))

                # save to file
                if boxes is not None:
                    res_file = os.path.join(
                        FLAGS.output_dir, '{}.txt'.format(
                            os.path.splitext(os.path.basename(im_fn))[0]))

                    with open(res_file, 'w') as f:
                        num = 0
                        for i in range(len(boxes)):
                            # to avoid submitting errors
                            box = boxes[i]
                            if np.linalg.norm(box[0] -
                                              box[1]) < 5 or np.linalg.norm(
                                                  box[3] - box[0]) < 5:
                                continue

                            num += 1

                            f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1]))
                            cv2.polylines(
                                im[:, :, ::-1],
                                [box.astype(np.int32).reshape((-1, 1, 2))],
                                True,
                                color=(255, 255, 0),
                                thickness=2)
                if not FLAGS.no_write_images:
                    img_path = os.path.join(FLAGS.output_dir,
                                            os.path.basename(im_fn))
                    cv2.imwrite(img_path, im[:, :, ::-1])

    #===========================================================================================================
    # Converting to 4-coordinate txt files
    # (test_data_path, output_dir and APP_ROOT are assumed to be module-level
    # configuration values)

    path = test_data_path + '/'  # input images
    gt_path = output_dir + '/'  # 8-coordinate txt
    out_path = APP_ROOT + '/output_label'  # 4-coordinate txt

    if not os.path.exists(out_path):
        os.makedirs(out_path)
    else:
        shutil.rmtree(out_path)
        os.mkdir(out_path)

    files = os.listdir(path)
    files.sort()
    #files=files[:100]
    for file in files:
        _, basename = os.path.split(file)
        if basename.lower().split('.')[-1] not in ['jpg', 'png', 'jpeg']:
            continue
        stem, ext = os.path.splitext(basename)
        gt_file = os.path.join(gt_path, stem + '.txt')
        img_path = os.path.join(path, file)
        print('Reading image ' + os.path.splitext(file)[0])
        img = cv2.imread(img_path)
        img_size = img.shape
        im_size_min = np.min(img_size[0:2])
        im_size_max = np.max(img_size[0:2])

        with open(gt_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            splitted_line = line.strip().lower().split(',')
            coords = [int(float(v)) for v in splitted_line[:8]]
            pt_x = np.array(coords[0::2], dtype=np.float64).reshape(4, 1)
            pt_y = np.array(coords[1::2], dtype=np.float64).reshape(4, 1)

            ind_x = np.argsort(pt_x, axis=0)
            pt_x = pt_x[ind_x]
            pt_y = pt_y[ind_x]

            if pt_y[0] < pt_y[1]:
                pt1 = (pt_x[0], pt_y[0])
                pt3 = (pt_x[1], pt_y[1])
            else:
                pt1 = (pt_x[1], pt_y[1])
                pt3 = (pt_x[0], pt_y[0])

            if pt_y[2] < pt_y[3]:
                pt2 = (pt_x[2], pt_y[2])
                pt4 = (pt_x[3], pt_y[3])
            else:
                pt2 = (pt_x[3], pt_y[3])
                pt4 = (pt_x[2], pt_y[2])

            xmin = int(min(pt1[0], pt2[0]))
            ymin = int(min(pt1[1], pt2[1]))
            xmax = int(max(pt2[0], pt4[0]))
            ymax = int(max(pt3[1], pt4[1]))

            if xmin < 0:
                xmin = 0
            if xmax > img_size[1] - 1:
                xmax = img_size[1] - 1
            if ymin < 0:
                ymin = 0
            if ymax > img_size[0] - 1:
                ymax = img_size[0] - 1

            with open(os.path.join(out_path, stem + '.txt'), 'a') as f:
                f.write('{} {} {} {}\n'.format(int(xmin), int(ymin),
                                               int(xmax), int(ymax)))
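
Examples #3 through #6 call resize_image(im) and expect back the resized image plus (ratio_h, ratio_w). A plausible sketch, assuming the common PSENet convention of capping the longest side and rounding both sides down to multiples of 32:

import cv2

def resize_image(im, max_side_len=1280):
    # Sketch under assumed conventions: limit the longest side to max_side_len,
    # then round each side to a multiple of 32 so the backbone downsamples cleanly.
    h, w, _ = im.shape
    ratio = float(max_side_len) / max(h, w) if max(h, w) > max_side_len else 1.0
    resize_h = max(32, int(h * ratio / 32) * 32)
    resize_w = max(32, int(w * ratio / 32) * 32)
    im_resized = cv2.resize(im, (resize_w, resize_h))
    return im_resized, (resize_h / float(h), resize_w / float(w))
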
Code Example #4
File: end2end.py  Project: 489397771/PSENET
def main(argv=None):
    import errno
    import os
    # os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    t0 = time.time()
    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:  # ignore "directory already exists"
            raise

    im_fn_list = get_images()
    for im_fn in im_fn_list:
        points_list = []
        tf.reset_default_graph()
        with tf.get_default_graph().as_default():
            input_images = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 3],
                                          name='input_images')
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            seg_maps_pred = model.model(input_images, is_training=False)

            variable_averages = tf.train.ExponentialMovingAverage(
                0.997, global_step)
            saver = tf.train.Saver(variable_averages.variables_to_restore())
            with tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True)) as sess:
                ckpt_state = tf.train.get_checkpoint_state(
                    FLAGS.checkpoint_path)
                model_path = os.path.join(
                    FLAGS.checkpoint_path,
                    os.path.basename(ckpt_state.model_checkpoint_path))

                logger.info('Restore from {}'.format(model_path))
                saver.restore(sess, model_path)

                im = cv2.imread(im_fn)[:, :, ::-1]
                draw_img = im[:, :, ::-1].copy()
                logger.debug('image file:{}'.format(im_fn))

                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                h, w, _ = im_resized.shape
                # options = tf.RunOptions(trace_level = tf.RunOptions.FULL_TRACE)
                # run_metadata = tf.RunMetadata()
                timer = {'net': 0, 'pse': 0}
                start = time.time()
                seg_maps = sess.run(seg_maps_pred,
                                    feed_dict={input_images: [im_resized]})
                timer['net'] = time.time() - start
                # fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                # chrome_trace = fetched_timeline.generate_chrome_trace_format()
                # with open(os.path.join(FLAGS.output_dir, os.path.basename(im_fn).split('.')[0]+'.json'), 'w') as f:
                #     f.write(chrome_trace)

                boxes, kernels, timer = detect(seg_maps=seg_maps,
                                               timer=timer,
                                               image_w=w,
                                               image_h=h)
                logger.info('{} : net {:.0f}ms, pse {:.0f}ms'.format(
                    im_fn, timer['net'] * 1000, timer['pse'] * 1000))

                if boxes is not None:
                    boxes = boxes.reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h
                    h, w, _ = im.shape
                    boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, w)
                    boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, h)

                duration = time.time() - start_time
                logger.info('[timing] {}'.format(duration))

                # save to file
                if boxes is not None:
                    res_file = os.path.join(
                        FLAGS.output_dir, '{}.txt'.format(
                            os.path.splitext(os.path.basename(im_fn))[0]))

                    with open(res_file, 'w') as f:
                        num = 0
                        for i in range(len(boxes)):
                            # to avoid submitting errors
                            box = boxes[i]
                            if np.linalg.norm(box[0] -
                                              box[1]) < 5 or np.linalg.norm(
                                                  box[3] - box[0]) < 5:
                                continue

                            num += 1

                            f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1]))

                            yDim, xDim = im[:, :, ::-1].shape[:2]
                            if box[0, 0] > box[2, 0]:  # box point 1 at bottom-right, clockwise order
                                pt1 = (max(1, box[2, 0]), max(1, box[2, 1]))
                                pt2 = (box[3, 0], box[3, 1])
                                pt3 = (min(box[0, 0],
                                           xDim - 2), min(yDim - 2, box[0, 1]))
                                pt4 = (box[1, 0], box[1, 1])
                            else:  # box point 1 at bottom-left, clockwise order
                                pt1 = (max(1, box[1, 0]), max(1, box[2, 1]))
                                pt2 = (box[2, 0], box[2, 1])
                                pt3 = (min(box[3, 0],
                                           xDim - 2), min(yDim - 2, box[3, 1]))
                                pt4 = (box[0, 0], box[0, 1])

                            points = [pt1, pt2, pt3, pt4]
                            points_list.append(points)

                            cv2.polylines(
                                im[:, :, ::-1],
                                [box.astype(np.int32).reshape((-1, 1, 2))],
                                True,
                                color=(255, 255, 0),
                                thickness=2)

        tf.reset_default_graph()
        keras.backend.clear_session()
        input_tensor = Input(shape=(32, None, 1), name='the_input')
        y_pred = dense_cnn(input_tensor, nclass)
        recognition_model = Model(inputs=input_tensor, outputs=y_pred)
        model_path = './recognition/...'
        if os.path.exists(model_path):
            print('loading models')
            recognition_model.load_weights(model_path)
        else:
            print('model does not exist')
            break

        j = 0
        txt_path = os.path.join(FLAGS.output_dir,
                                im_fn.split('/')[-1].split('.')[0])
        with open('{}.txt'.format(txt_path), 'a', encoding='utf-8') as outf:
            for points in points_list:
                j += 1
                pt1 = points[0]
                pt2 = points[1]
                pt3 = points[2]
                pt4 = points[3]
                degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0]))
                text_img = dumpRotateImage(im[:, :, ::-1], degree, pt1, pt2,
                                           pt3, pt4)
                text_img = cv2.cvtColor(text_img, cv2.COLOR_BGR2GRAY)

                text_h, text_w = text_img.shape[:2]
                if text_h // text_w > 1:
                    # skip boxes taller than they are wide (likely vertical text)
                    continue
                dst_h = 32
                dst_w = text_w * dst_h // text_h
                text_img = cv2.resize(text_img, (dst_w, dst_h))
                X = text_img.reshape([1, 32, -1, 1])
                y_pred = recognition_model.predict(X)
                out = _decode(y_pred)
                img_PIL = Image.fromarray(
                    cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB))
                font = ImageFont.truetype('./utils/simsun.ttc', 12)
                fillColor = (255, 0, 0)
                draw = ImageDraw.Draw(img_PIL)
                if out is None:
                    out = ''
                draw.text(pt4, out, font=font, fill=fillColor)
                draw_img = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
                outf.write('{}. \t{}\n'.format(j, out))

            if not FLAGS.no_write_images:
                img_path = os.path.join(FLAGS.output_dir,
                                        os.path.basename(im_fn))
                cv2.imwrite(img_path, draw_img)

    print('total time = ', time.time() - t0)
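
_decode(y_pred) in example #4 turns the recognizer's per-timestep softmax output into a string. A minimal CTC-style greedy decode sketch; char_map is a hypothetical index-to-character table, and the blank label is assumed to be the last class:

import numpy as np

def _decode(y_pred, char_map=None):
    # Greedy CTC decode sketch: argmax per timestep, collapse repeats,
    # drop the blank label (assumed to be the last class index).
    char_map = char_map or {}  # hypothetical index -> character table
    blank = y_pred.shape[-1] - 1
    out, prev = [], -1
    for idx in np.argmax(y_pred[0], axis=1):
        if idx != prev and idx != blank:
            out.append(char_map.get(int(idx), ''))
        prev = idx
    return ''.join(out)
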
Code Example #5
def main(argv=None):
    import errno
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:  # ignore "directory already exists"
            raise

    if not os.path.isdir(os.path.join(FLAGS.output_dir, "crop")):
        os.makedirs(os.path.join(FLAGS.output_dir, "crop"))

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 3],
                                      name='input_images')
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        seg_maps_pred = model.model(input_images, is_training=False)

        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            # ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            # model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            # logger.info('Restore from {}'.format(model_path))
            saver.restore(sess, FLAGS.checkpoint_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                im = cv2.imread(im_fn)[:, :, ::-1]
                logger.debug('image file:{}'.format(im_fn))

                start_time = time.time()
                im_resized, (ratio_h, ratio_w) = resize_image(im)
                h, w, _ = im_resized.shape
                # options = tf.RunOptions(trace_level = tf.RunOptions.FULL_TRACE)
                # run_metadata = tf.RunMetadata()
                timer = {'net': 0, 'pse': 0}
                start = time.time()
                seg_maps = sess.run(seg_maps_pred,
                                    feed_dict={input_images: [im_resized]})
                timer['net'] = time.time() - start
                # fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                # chrome_trace = fetched_timeline.generate_chrome_trace_format()
                # with open(os.path.join(FLAGS.output_dir, os.path.basename(im_fn).split('.')[0]+'.json'), 'w') as f:
                #     f.write(chrome_trace)

                boxes, kernels, timer = detect(seg_maps=seg_maps,
                                               timer=timer,
                                               image_w=w,
                                               image_h=h)
                logger.info('{} : net {:.0f}ms, pse {:.0f}ms'.format(
                    im_fn, timer['net'] * 1000, timer['pse'] * 1000))

                if boxes is not None:
                    boxes = boxes.reshape((-1, 4, 2))
                    boxes[:, :, 0] /= ratio_w
                    boxes[:, :, 1] /= ratio_h
                    h, w, _ = im.shape
                    boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, w)
                    boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, h)

                duration = time.time() - start_time
                logger.info('[timing] {}'.format(duration))

                # save to file
                if boxes is not None:
                    res_file = os.path.join(
                        FLAGS.output_dir, '{}.txt'.format(
                            os.path.splitext(os.path.basename(im_fn))[0]))

                    with open(res_file, 'w') as f:
                        num = 0
                        for i in range(len(boxes)):
                            # to avoid submitting errors
                            box = boxes[i]
                            if np.linalg.norm(box[0] -
                                              box[1]) < 5 or np.linalg.norm(
                                                  box[3] - box[0]) < 5:
                                continue

                            num += 1

                            f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                                box[0, 0], box[0, 1], box[1, 0], box[1, 1],
                                box[2, 0], box[2, 1], box[3, 0], box[3, 1]))
                            if not FLAGS.is_cropping:
                                cv2.polylines(
                                    im[:, :, ::-1],
                                    [box.astype(np.int32).reshape((-1, 1, 2))],
                                    True,
                                    color=(255, 255, 0),
                                    thickness=2)
                            else:
                                # the code below treats box[2] as top-left,
                                # box[3] as top-right, box[1] as bottom-left
                                # and box[0] as bottom-right, then snaps the
                                # quadrilateral to an axis-aligned rectangle
                                # before cropping
                                lt_x = box[2, 0]
                                lt_y = box[2, 1]
                                rt_x = box[3, 0]
                                rt_y = box[3, 1]
                                lb_x = box[1, 0]
                                lb_y = box[1, 1]
                                rb_x = box[0, 0]
                                rb_y = box[0, 1]
                                if lt_x > lb_x:
                                    lt_x = lb_x
                                if lt_y > rt_y:
                                    lt_y = rt_y
                                if rt_x < rb_x:
                                    rt_x = rb_x
                                if rt_y > lt_y:
                                    rt_y = lt_y
                                if lb_x > lt_x:
                                    lb_x = lt_x
                                if lb_y < rb_y:
                                    lb_y = rb_y
                                if rb_x < rt_x:
                                    rb_x = rt_x
                                if rb_y < lb_y:
                                    rb_y = lb_y


                                # padding = 3
                                # lt_x -= padding
                                # lt_y -= padding
                                # lb_x -= padding
                                # lb_y += padding
                                # rt_x += padding
                                # rt_y -= padding
                                # rb_x += padding
                                # rb_y += padding
                                crop_img = im[int(lt_y):int(lb_y),
                                              int(lt_x):int(rt_x)]
                                cv2.imwrite(
                                    os.path.join(FLAGS.output_dir, "crop",
                                                 ("%d_" % i) +
                                                 os.path.basename(im_fn)),
                                    crop_img[:, :, ::-1])

                if not FLAGS.no_write_images:
                    img_path = os.path.join(FLAGS.output_dir,
                                            os.path.basename(im_fn))
                    cv2.imwrite(img_path, im[:, :, ::-1])
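
get_images(), used by examples #3 through #5, is not shown; it presumably collects the test image paths. A sketch, assuming a FLAGS.test_data_path flag and the module-level logger used elsewhere in these scripts:

import os

def get_images(exts=('jpg', 'jpeg', 'png', 'JPG')):
    # Sketch: walk FLAGS.test_data_path (assumed flag name) and collect
    # every file whose extension looks like an image.
    files = []
    for parent, _, filenames in os.walk(FLAGS.test_data_path):
        for filename in filenames:
            if filename.split('.')[-1] in exts:
                files.append(os.path.join(parent, filename))
    logger.info('Found {} images'.format(len(files)))
    return files
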
Code Example #6
File: main.py  Project: zzzdtz/ocr_invoice
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    input_dir = '../test_images'
    output_dir = '../test_images_result'
    img_path = '../test_images/IMG_20190107_103022.jpg'
    # img_path = '../test_images/IMG_20190108_162848.jpg'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    g1 = tf.Graph()  # graph loaded into Session 1
    sess1 = tf.Session(graph=g1, config=config)  # Session 1
    # load the first model
    with sess1.as_default():
        with g1.as_default():
            with tf.gfile.FastGFile('../models/ocr_angle.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')
            preprocessing_name1 = 'resnet_v1_50'
            image_preprocessing_fn1 = preprocessing_factory.get_preprocessing(
                preprocessing_name1, is_training=False)
            test_image_size1 = 224
            graph = tf.get_default_graph()
            predictions1 = graph.get_tensor_by_name(
                'resnet_v1_50/predictions/Reshape_1:0')
            tensor_input1 = graph.get_tensor_by_name('Placeholder:0')
            tensor_item1 = tf.placeholder(tf.float32, [None, None, 3])
            processed_image1 = image_preprocessing_fn1(tensor_item1,
                                                       test_image_size1,
                                                       test_image_size1)

    g2 = tf.Graph()
    sess2 = tf.Session(graph=g2, config=config)
    with sess2.as_default():
        with g2.as_default():
            with tf.gfile.FastGFile('../models/pse_invoice.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')

            graph = tf.get_default_graph()
            input_images2 = graph.get_tensor_by_name('input_images:0')
            seg_maps_pred2 = graph.get_tensor_by_name('Sigmoid:0')

    g3 = tf.Graph()
    sess3 = tf.Session(graph=g3, config=config)
    with sess3.as_default():
        with g3.as_default():
            with tf.gfile.FastGFile('../models/ocr_field.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')

            preprocessing_name3 = 'resnet_v1_50'
            image_preprocessing_fn3 = preprocessing_factory.get_preprocessing(
                preprocessing_name3, is_training=False)
            output_height3, output_width3 = 100, 400
            graph = tf.get_default_graph()
            predictions3 = graph.get_tensor_by_name(
                'resnet_v1_50/predictions/Reshape_1:0')
            tensor_input3 = graph.get_tensor_by_name('Placeholder:0')
            tensor_item3 = tf.placeholder(tf.float32, [None, None, 3])
            processed_image3 = image_preprocessing_fn3(tensor_item3,
                                                       output_height3,
                                                       output_width3)

    g4 = tf.Graph()
    sess4 = tf.Session(graph=g4, config=config)
    with sess4.as_default():
        with g4.as_default():
            with tf.gfile.FastGFile('../models/ocr_code.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')

            graph = tf.get_default_graph()
            image_inputs4 = graph.get_tensor_by_name('image_inputs:0')
            logits = graph.get_tensor_by_name('transpose_2:0')
            shape = tf.shape(logits)
            seq_len = tf.reshape(shape[0], [-1])
            seq_len = tf.tile(seq_len, [shape[1]])
            greedy_decoder = tf.nn.ctc_greedy_decoder(logits, seq_len)
            decoded4 = greedy_decoder[0]

    g5 = tf.Graph()
    sess5 = tf.Session(graph=g5, config=config)
    with sess5.as_default():
        with g5.as_default():
            with tf.gfile.FastGFile('../models/ocr_number.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')

            graph = tf.get_default_graph()
            image_inputs5 = graph.get_tensor_by_name('image_inputs:0')
            logits = graph.get_tensor_by_name('transpose_2:0')
            shape = tf.shape(logits)
            seq_len = tf.reshape(shape[0], [-1])
            seq_len = tf.tile(seq_len, [shape[1]])
            greedy_decoder = tf.nn.ctc_greedy_decoder(logits, seq_len)
            decoded5 = greedy_decoder[0]

    g6 = tf.Graph()
    sess6 = tf.Session(graph=g6, config=config)
    with sess6.as_default():
        with g6.as_default():
            with tf.gfile.FastGFile('../models/ocr_price.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')

            graph = tf.get_default_graph()
            image_inputs6 = graph.get_tensor_by_name('image_inputs:0')
            logits = graph.get_tensor_by_name('transpose_2:0')
            shape = tf.shape(logits)
            seq_len = tf.reshape(shape[0], [-1])
            seq_len = tf.tile(seq_len, [shape[1]])
            greedy_decoder = tf.nn.ctc_greedy_decoder(logits, seq_len)
            decoded6 = greedy_decoder[0]

    g7 = tf.Graph()
    sess7 = tf.Session(graph=g7, config=config)
    with sess7.as_default():
        with g7.as_default():
            with tf.gfile.FastGFile('../models/ocr_date.pb', 'rb') as f:
                # use tf.GraphDef() to define an empty Graph
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                # Imports the graph from graph_def into the current default Graph.
                tf.import_graph_def(graph_def, name='')

            graph = tf.get_default_graph()
            image_inputs7 = graph.get_tensor_by_name('image_inputs:0')
            logits = graph.get_tensor_by_name('transpose_2:0')
            shape = tf.shape(logits)
            seq_len = tf.reshape(shape[0], [-1])
            seq_len = tf.tile(seq_len, [shape[1]])
            greedy_decoder = tf.nn.ctc_greedy_decoder(logits, seq_len)
            decoded7 = greedy_decoder[0]

    # at inference time
    with sess1.as_default():
        with sess1.graph.as_default():
            sess1.run(tf.global_variables_initializer())
            src = cv2.imread(img_path)
            image = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image, (test_image_size1, test_image_size1))
            height, width = image.shape[0], image.shape[1]
            image = np.reshape(image, [height, width, 3])
            image = sess1.run(processed_image1,
                              feed_dict={tensor_item1: image})
            logi = sess1.run(predictions1, feed_dict={tensor_input1:
                                                      [image]})[0]
            prediction = np.argmax(logi)
            degree = labels_angle[prediction]
            print(img_path, degree)

    # at inference time
    detected_rects = {}
    with sess2.as_default():
        with sess2.graph.as_default():
            im_src = rotate(src, degree)
            im = copy.deepcopy(im_src)
            im = im[:, :, ::-1]
            # im = cv2.imread(img_path)[:, :, ::-1]
            logger.debug('image file:{}'.format(img_path))
            os.makedirs(output_dir, exist_ok=True)
            start_time = time.time()
            im_resized, (ratio_h, ratio_w) = resize_image(im)
            h, w, _ = im_resized.shape
            timer = {'net': 0, 'pse': 0}
            start = time.time()
            seg_maps = sess2.run(seg_maps_pred2,
                                 feed_dict={input_images2: [im_resized]})
            timer['net'] = time.time() - start

            boxes, kernels, timer = detect(seg_maps=seg_maps,
                                           timer=timer,
                                           image_w=w,
                                           image_h=h)
            logger.info('{} : net {:.0f}ms, pse {:.0f}ms'.format(
                img_path, timer['net'] * 1000, timer['pse'] * 1000))

            if boxes is not None:
                boxes = boxes.reshape((-1, 4, 2))
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                h, w, _ = im.shape
                boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, w)
                boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, h)

            duration = time.time() - start_time
            logger.info('[timing] {}'.format(duration))
            if boxes is not None:
                num = 0
                for i in range(len(boxes)):
                    # to avoid submitting errors
                    box = boxes[i]
                    if np.linalg.norm(box[0] - box[1]) < 2 or np.linalg.norm(
                            box[3] - box[0]) < 2:
                        continue
                    cv2.polylines(im[:, :, ::-1],
                                  [box.astype(np.int32).reshape((-1, 1, 2))],
                                  True,
                                  color=(255, 255, 0),
                                  thickness=2)
                    detected_rects[num] = [
                        box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0],
                        box[2, 1], box[3, 0], box[3, 1]
                    ]
                    num += 1

                # img_text_crop.work(im_src, res_file, output_dir)

            img_path = os.path.join(output_dir, os.path.basename(img_path))
            cv2.imwrite(img_path, im[:, :, ::-1])
            # show_score_geo(im_resized, kernels, im)

    img_items = {}
    img_items_normal = {}
    img_text_crop.work_for_array(im_src, detected_rects, img_items,
                                 img_items_normal)

    ocr_items = {}
    with sess3.as_default():
        with sess3.graph.as_default():
            sess3.run(tf.global_variables_initializer())
            min_value = 5000  # large sentinel; the leftmost class-2 box is later taken as the price field
            price_index = 0
            for key in img_items_normal.keys():
                image = img_items_normal[key]
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = cv2.resize(image, (output_width3, output_height3))
                height, width = image.shape[0], image.shape[1]
                image = np.reshape(image, [height, width, 3])
                image = sess3.run(processed_image3,
                                  feed_dict={tensor_item3: image})
                logi = sess3.run(predictions3,
                                 feed_dict={tensor_input3: [image]})[0]
                prediction = int(np.argmax(logi))
                if prediction < 4:
                    if prediction != 2:
                        ocr_items[prediction] = img_items[key]
                    else:
                        x = int(detected_rects[key][0])
                        if x < min_value:
                            min_value = x
                            price_index = key
            ocr_items[2] = img_items[price_index]

    with sess4.as_default():
        with sess4.graph.as_default():
            spliter4 = ocr_code_utils.OcrSpliter()
            image = spliter4.ocr_split('', ocr_items[0])
            # image = cv2.imread(path, 0)
            imgw = np.asarray(image, dtype=np.float32)
            image = (imgw - np.mean(imgw)) / np.std(imgw)
            image = np.reshape(image, [image.shape[0], image.shape[1], 1])
            feed = {image_inputs4: [image]}
            predictions_result = sess4.run(decoded4[0], feed_dict=feed)
            predictions_result = sparse_tuple_to(predictions_result,
                                                 max_sequence_length_code)
            predictions_result = predictions_result[0]
            predictions_result = [
                labels_code[label] for label in predictions_result
            ]
            pred = ''
            for ch in predictions_result:
                if ch not in ['<BLANK>']:
                    pred += ch
            print('invoice_code:', pred)

    with sess5.as_default():
        with sess5.graph.as_default():
            spliter5 = ocr_number_utils.OcrSpliter()
            image = spliter5.ocr_split('', ocr_items[1])
            # image = cv2.imread(path, 0)
            imgw = np.asarray(image, dtype=np.float32)
            image = (imgw - np.mean(imgw)) / np.std(imgw)
            image = np.reshape(image, [image.shape[0], image.shape[1], 1])
            feed = {image_inputs5: [image]}
            predictions_result = sess5.run(decoded5[0], feed_dict=feed)
            predictions_result = sparse_tuple_to(predictions_result,
                                                 max_sequence_length_number)
            predictions_result = predictions_result[0]
            predictions_result = [
                labels_number[label] for label in predictions_result
            ]
            pred = ''
            for ch in predictions_result:
                if ch not in ['<BLANK>']:
                    pred += ch
            print('invoice_number:', pred)

    with sess6.as_default():
        with sess6.graph.as_default():
            spliter6 = ocr_price_utils.OcrSpliter()
            image = spliter6.ocr_split('', ocr_items[2])
            # image = cv2.imread(path, 0)
            imgw = np.asarray(image, dtype=np.float32)
            image = (imgw - np.mean(imgw)) / np.std(imgw)
            image = np.reshape(image, [image.shape[0], image.shape[1], 1])
            feed = {image_inputs6: [image]}
            predictions_result = sess6.run(decoded6[0], feed_dict=feed)
            predictions_result = sparse_tuple_to(predictions_result,
                                                 max_sequence_length_price)
            predictions_result = predictions_result[0]
            predictions_result = [
                labels_price[label] for label in predictions_result
            ]
            pred = ''
            for ch in predictions_result:
                if ch not in ['<BLANK>']:
                    pred += ch
            # re-insert the decimal point before the last two digits
            # (the model output omits it)
            pred = pred[:-2] + '.' + pred[-2:]
            print('invoice_price:', pred)

    with sess7.as_default():
        with sess7.graph.as_default():
            spliter7 = ocr_date_utils.OcrSpliter()
            image = spliter7.ocr_split('', ocr_items[3])
            # image = cv2.imread(path, 0)
            imgw = np.asarray(image, dtype=np.float32)
            image = (imgw - np.mean(imgw)) / np.std(imgw)
            image = np.reshape(image, [image.shape[0], image.shape[1], 1])
            feed = {image_inputs7: [image]}
            predictions_result = sess7.run(decoded7[0], feed_dict=feed)
            predictions_result = sparse_tuple_to(predictions_result,
                                                 max_sequence_length_date)
            predictions_result = predictions_result[0]
            predictions_result = [
                labels_date[label] for label in predictions_result
            ]
            pred = ''
            for ch in predictions_result:
                if ch not in ['<BLANK>']:
                    pred += ch
            print('invoice_date:', pred)

    # close the sessions
    sess1.close()
    sess2.close()
    sess3.close()
    sess4.close()
    sess5.close()
    sess6.close()
    sess7.close()
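
sparse_tuple_to, called after each ctc_greedy_decoder run in example #6, converts the returned tf.SparseTensorValue into dense, fixed-length label sequences. A sketch; the padding label `default` is an assumption:

import numpy as np

def sparse_tuple_to(sparse_value, max_sequence_length, default=0):
    # Sketch: expand a tf.SparseTensorValue (indices, values, dense_shape)
    # into one fixed-length label list per batch item, padded with `default`.
    batch_size = int(sparse_value.dense_shape[0])
    dense = np.full((batch_size, max_sequence_length), default, dtype=np.int64)
    for (row, col), val in zip(sparse_value.indices, sparse_value.values):
        if col < max_sequence_length:
            dense[int(row), int(col)] = val
    return dense.tolist()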