Example #1
def main(_):
    checkpoint_dir = os.path.join(FLAGS.exp_dir, 'log')
    # eval_dir = os.path.join(FLAGS.exp_dir, 'log/eval')
    model_config, _, _ = get_configs_from_exp_dir()

    model = model_builder.build(model_config, is_training=False)

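    # Feed the model a raw JPEG string: decode it to an RGB tensor and
    # resize to the 64x256 input size used by this demo.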
    input_image_str_tensor = tf.placeholder(dtype=tf.string, shape=[])
    input_image_tensor = tf.image.decode_jpeg(
        input_image_str_tensor,
        channels=3,
    )
    resized_image_tensor = tf.image.resize_images(
        tf.to_float(input_image_tensor), [64, 256])

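    # Build the inference graph: predict control points and rectified
    # images, then postprocess the predictions into a text transcription.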
    predictions_dict = model.predict(tf.expand_dims(resized_image_tensor, 0))
    recognitions = model.postprocess(predictions_dict)
    recognition_text = recognitions['text'][0]
    control_points = predictions_dict['control_points']
    rectified_images = predictions_dict['rectified_images']

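    # Restore the trained weights from the checkpoint under exp_dir/log.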
    saver = tf.train.Saver(tf.global_variables())
    checkpoint = os.path.join(FLAGS.exp_dir, 'log/model.ckpt')

    fetches = {
        'original_image': input_image_tensor,
        'recognition_text': recognition_text,
        'control_points': predictions_dict['control_points'],
        'rectified_images': predictions_dict['rectified_images'],
    }

    with open(FLAGS.input_image, 'rb') as f:
        input_image_str = f.read()

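    # Initialize variables and lookup tables, then restore the checkpoint
    # before running inference.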
    with tf.Session() as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.tables_initializer()
        ])
        saver.restore(sess, checkpoint)
        sess_outputs = sess.run(
            fetches, feed_dict={input_image_str_tensor: input_image_str})

    print('Recognized text: {}'.format(
        sess_outputs['recognition_text'].decode('utf-8')))

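    # The rectified image is in [-1, 1]; map it back to 8-bit pixel values
    # before saving with PIL.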
    rectified_image = sess_outputs['rectified_images'][0]
    rectified_image_pil = Image.fromarray(
        (128 * (rectified_image + 1.0)).astype(np.uint8))
    input_image_dir = os.path.dirname(FLAGS.input_image)
    rectified_image_save_path = os.path.join(input_image_dir,
                                             'rectified_image.jpg')
    rectified_image_pil.save(rectified_image_save_path)
    print('Rectified image saved to {}'.format(rectified_image_save_path))
Example #2
    def test_single_predictor_model_training(self):
        model_proto = model_pb2.Model()
        text_format.Merge(SINGLE_PREDICTOR_MODEL_TEXT_PROTO, model_proto)
        # Build the model in training mode (is_training=True).
        model_object = model_builder.build(model_proto, True)
        test_groundtruth_text_list = [
            tf.constant(b'hello', dtype=tf.string),
            tf.constant(b'world', dtype=tf.string)
        ]
        # Training requires groundtruth transcriptions to compute the loss.
        model_object.provide_groundtruth(
            {'groundtruth_text': test_groundtruth_text_list})
        test_input_image = tf.random_uniform(shape=[2, 32, 100, 3],
                                             minval=0,
                                             maxval=255,
                                             dtype=tf.float32,
                                             seed=1)
        prediction_dict = model_object.predict(
            model_object.preprocess(test_input_image))
        loss = model_object.loss(prediction_dict)
        with self.test_session() as sess:
            sess.run(
                [tf.global_variables_initializer(),
                 tf.tables_initializer()])
            outputs = sess.run({'loss': loss})
            print(outputs['loss'])
Example #3
    def test_stn_multi_predictor_model_inference(self):
        model_proto = model_pb2.Model()
        text_format.Merge(STN_MULTIPLE_PREDICTOR_MODEL_TEXT_PROTO, model_proto)
        # Build the model in inference mode (is_training=False).
        model_object = model_builder.build(model_proto, False)
        test_groundtruth_text_list = [
            tf.constant(b'hello', dtype=tf.string),
            tf.constant(b'world', dtype=tf.string)
        ]
        model_object.provide_groundtruth(
            {'groundtruth_text': test_groundtruth_text_list})
        test_input_image = tf.random_uniform(shape=[2, 32, 100, 3],
                                             minval=0,
                                             maxval=255,
                                             dtype=tf.float32,
                                             seed=1)
        prediction_dict = model_object.predict(
            model_object.preprocess(test_input_image))
        # In inference mode, postprocess decodes the predictions into text.
        recognition_dict = model_object.postprocess(prediction_dict)
        with self.test_session() as sess:
            sess.run(
                [tf.global_variables_initializer(),
                 tf.tables_initializer()])
            outputs = sess.run(recognition_dict)
            print(outputs)
Example #4
    def test_build_attention_model_single_branch(self):
        model_text_proto = """
    attention_recognition_model {
      feature_extractor {
        convnet {
          crnn_net {
            net_type: SINGLE_BRANCH
            conv_hyperparams {
              op: CONV
              regularizer { l2_regularizer { weight: 1e-4 } }
              initializer { variance_scaling_initializer { } }
              batch_norm { }
            }
            summarize_activations: false
          }
        }
        bidirectional_rnn {
          fw_bw_rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer {} }
            }
          }
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          num_output_units: 256
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer { } }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
        summarize_activations: true
      }

      predictor {
        name: "ForwardPredictor"
        bahdanau_attention_predictor {
          reverse: false
          rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer { } }
            }
          }
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          num_attention_units: 128
          max_num_steps: 10
          multi_attention: false
          beam_width: 1
          label_map {
            character_set {
              text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
              delimiter: ""
            }
            label_offset: 2
          }
          loss {
            sequence_cross_entropy_loss {
              sequence_normalize: false
              sample_normalize: true
            }
          }
        }
      }
    }
    """
        model_proto = model_pb2.Model()
        text_format.Merge(model_text_proto, model_proto)
        model_object = model_builder.build(model_proto, True)

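        # Provide groundtruth transcriptions; the model needs them to build
        # the training loss.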
        test_groundtruth_text_list = [
            tf.constant(b'hello', dtype=tf.string),
            tf.constant(b'world', dtype=tf.string)
        ]
        model_object.provide_groundtruth(test_groundtruth_text_list)
        test_input_image = tf.random_uniform(shape=[2, 32, 100, 3],
                                             minval=0,
                                             maxval=255,
                                             dtype=tf.float32,
                                             seed=1)
        prediction_dict = model_object.predict(
            model_object.preprocess(test_input_image))
        loss = model_object.loss(prediction_dict)

        with self.test_session() as sess:
            sess.run(
                [tf.global_variables_initializer(),
                 tf.tables_initializer()])
            outputs = sess.run({'loss': loss})
            print(outputs['loss'])
Example #5
File: demo.py Project: sungjune-p/aster
def main(_):
    checkpoint_dir = os.path.join(FLAGS.exp_dir, 'log')
    # eval_dir = os.path.join(FLAGS.exp_dir, 'log/eval')
    model_config, _, _ = get_configs_from_exp_dir()

    model = model_builder.build(model_config, is_training=False)

    input_image_str_tensor = tf.placeholder(dtype=tf.string, shape=[])
    input_image_tensor = tf.image.decode_jpeg(
        input_image_str_tensor,
        channels=3,
    )
    resized_image_tensor = tf.image.resize_images(
        tf.to_float(input_image_tensor), [64, 256])

    predictions_dict = model.predict(tf.expand_dims(resized_image_tensor, 0))
    recognitions = model.postprocess(predictions_dict)
    recognition_text = recognitions['text'][0]
    control_points = predictions_dict['control_points']
    rectified_images = predictions_dict['rectified_images']

    saver = tf.train.Saver(tf.global_variables())
    checkpoint = os.path.join(FLAGS.exp_dir, 'log/model.ckpt')

    fetches = {
        'original_image': input_image_tensor,
        'recognition_text': recognition_text,
        'control_points': predictions_dict['control_points'],
        'rectified_images': predictions_dict['rectified_images'],
    }

    with tf.Session() as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.tables_initializer()
        ])
        saver.restore(sess, checkpoint)

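        # Iterate over the cropped images in sorted order. File names encode
        # the video id (after a 4-character prefix) and the key-frame number.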
        image_list = os.listdir(FLAGS.data_dir)
        image_list.sort()
        vid_num = '00000'
        for file in image_list:
            with open(os.path.join(FLAGS.data_dir, file), 'rb') as f:
                input_image_str = f.read()

            # When the video id changes, load its .tsv file to get per-frame info.
            if vid_num != file.split('_')[0][4:]:
                tsv_path = os.path.join(FLAGS.tsv_dir,
                                        file.split('_')[0][4:] + '.tsv')
                with open(tsv_path) as tsv_file:
                    tsv_reader = csv.reader(tsv_file, delimiter="\t")
                    next(tsv_reader)  # skip the header row
                    frame_data = list(tsv_reader)

            sess_outputs = sess.run(
                fetches, feed_dict={input_image_str_tensor: input_image_str})
            text = sess_outputs['recognition_text'].decode('utf-8')

            print('Recognized text for {}: {}'.format(file, text))

            # rectified_image = sess_outputs['rectified_images'][0]
            # rectified_image_pil = Image.fromarray((128 * (rectified_image + 1.0)).astype(np.uint8))
            # input_image_dir = os.path.dirname(FLAGS.input_image)
            # rectified_image_save_path = os.path.join(FLAGS.data_dir, 'rectifed %s' %file)
            # rectified_image_pil.save(rectified_image_save_path)
            # print('Rectified image saved to {}'.format(rectified_image_save_path))
            # print('Check Video Number : ', file.split('_')[0])
            print('Check Image Name : ', '_'.join(file.split('_')[:2]))
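            # Look up this key frame's start/end frame indices and timestamps
            # in the per-video .tsv metadata (key frame numbers are 1-based).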
            video_number = int(file.split('_')[0][4:])
            key_frame_num = int(file.split('_')[1])
            start_frame = int(frame_data[key_frame_num - 1][0])
            end_frame = int(frame_data[key_frame_num - 1][2])
            start_time = float(frame_data[key_frame_num - 1][1])
            end_time = float(frame_data[key_frame_num - 1][3])
            # frame_seg = frame_data[key_frame_num-1][0]+'-'+frame_data[key_frame_num-1][2]
            # video_name_and_frame = '_'.join([file.split('_')[0][4:], frame_seg])

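            # Upsert this clip's record in MongoDB, accumulating recognized
            # texts with $addToSet. (Collection.update is deprecated in
            # PyMongo 3+; update_one is the modern equivalent.)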
            col.update(
                {"video": video_number, "startFrame": start_frame},
                {
                    "$set": {
                        "video": video_number,
                        "startFrame": start_frame,
                        "endFrame": end_frame,
                        "startSecond": start_time,
                        "endSecond": end_time
                    },
                    "$addToSet": {"text": text}
                },
                upsert=True)

            vid_num = file.split('_')[0][4:]