コード例 #1
0
 def test_stn_multi_predictor_model_inference(self):
     """Smoke-test the STN multi-predictor model end to end at inference.

     The model is built from STN_MULTIPLE_PREDICTOR_MODEL_TEXT_PROTO with the
     builder's second argument set to False (presumably the is_training flag
     -- TODO confirm against model_builder.build). Two groundtruth strings
     are provided, a seeded random float tensor is pushed through
     preprocess -> predict -> postprocess, and the fetched result is printed.
     Nothing is asserted: the test only checks that the graph builds and runs.
     """
     proto = model_pb2.Model()
     text_format.Merge(STN_MULTIPLE_PREDICTOR_MODEL_TEXT_PROTO, proto)
     model = model_builder.build(proto, False)

     # One scalar tf.string constant per groundtruth word, created in order.
     groundtruth_texts = [tf.constant(word, dtype=tf.string)
                          for word in (b'hello', b'world')]
     model.provide_groundtruth({'groundtruth_text': groundtruth_texts})

     # Fixed op-level seed; values are drawn uniformly from [0, 255).
     random_input = tf.random_uniform(
         shape=[2, 32, 100, 3], minval=0, maxval=255,
         dtype=tf.float32, seed=1)
     preprocessed = model.preprocess(random_input)
     predictions = model.predict(preprocessed)
     recognitions = model.postprocess(predictions)

     with self.test_session() as session:
         # Initializer ops are created inside the session context, as before,
         # so any device scoping applied by test_session also covers them.
         init_ops = [tf.global_variables_initializer(),
                     tf.tables_initializer()]
         session.run(init_ops)
         fetched = session.run(recognitions)
         print(fetched)
コード例 #2
0
 def test_single_predictor_model_training(self):
     """Smoke-test the single-predictor model's loss computation.

     The model is built from SINGLE_PREDICTOR_MODEL_TEXT_PROTO with the
     builder's second argument set to True (presumably the is_training flag
     -- TODO confirm against model_builder.build). Two groundtruth strings
     are provided, a seeded random float tensor goes through preprocess ->
     predict, the loss op is built from the predictions, and its value is
     fetched and printed. No assertion is made on the loss value.
     """
     proto = model_pb2.Model()
     text_format.Merge(SINGLE_PREDICTOR_MODEL_TEXT_PROTO, proto)
     model = model_builder.build(proto, True)

     words = (b'hello', b'world')
     groundtruth_texts = [tf.constant(w, dtype=tf.string) for w in words]
     model.provide_groundtruth({'groundtruth_text': groundtruth_texts})

     random_input = tf.random_uniform(shape=[2, 32, 100, 3],
                                      minval=0,
                                      maxval=255,
                                      dtype=tf.float32,
                                      seed=1)
     predictions = model.predict(model.preprocess(random_input))
     loss_op = model.loss(predictions)

     with self.test_session() as session:
         init_ops = [tf.global_variables_initializer(),
                     tf.tables_initializer()]
         session.run(init_ops)
         fetched = session.run({'loss': loss_op})
         # loss_op's structure (scalar vs. dict of losses) is not visible here.
         print(fetched['loss'])
コード例 #3
0
    def test_build_attention_model_single_branch(self):
        """Build an attention recognition model from an inline proto and run its loss.

        The text proto configures a single-branch CRNN feature extractor with
        a bidirectional LSTM, plus one forward Bahdanau-attention predictor
        with a case-sensitive a-zA-Z character set. The model is built with
        the builder's second argument set to True (presumably is_training --
        TODO confirm). A seeded random float tensor goes through preprocess ->
        predict, the loss is built, and its value is fetched and printed.
        No assertion is made on the loss value.
        """
        # Each singular field appears once. text_format.Merge would silently
        # accept a duplicate (last value wins), but text_format.Parse rejects it.
        model_text_proto = """
    attention_recognition_model {
      feature_extractor {
        convnet {
          crnn_net {
            net_type: SINGLE_BRANCH
            conv_hyperparams {
              op: CONV
              regularizer { l2_regularizer { weight: 1e-4 } }
              initializer { variance_scaling_initializer { } }
              batch_norm { }
            }
            summarize_activations: false
          }
        }
        bidirectional_rnn {
          fw_bw_rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer {} }
            }
          }
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          num_output_units: 256
          fc_hyperparams {
            op: FC
            activation: RELU
            initializer { variance_scaling_initializer { } }
            regularizer { l2_regularizer { weight: 1e-4 } }
          }
        }
        summarize_activations: true
      }

      predictor {
        name: "ForwardPredictor"
        bahdanau_attention_predictor {
          reverse: false
          rnn_cell {
            lstm_cell {
              num_units: 256
              forget_bias: 1.0
              initializer { orthogonal_initializer { } }
            }
          }
          rnn_regularizer { l2_regularizer { weight: 1e-4 } }
          num_attention_units: 128
          max_num_steps: 10
          multi_attention: false
          beam_width: 1
          label_map {
            character_set {
              text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
              delimiter: ""
            }
            label_offset: 2
          }
          loss {
            sequence_cross_entropy_loss {
              sequence_normalize: false
              sample_normalize: true
            }
          }
        }
      }
    }
    """
        model_proto = model_pb2.Model()
        text_format.Merge(model_text_proto, model_proto)
        model_object = model_builder.build(model_proto, True)

        test_groundtruth_text_list = [
            tf.constant(b'hello', dtype=tf.string),
            tf.constant(b'world', dtype=tf.string)
        ]
        # NOTE(review): the multi-predictor tests pass a dict
        # {'groundtruth_text': [...]}, but this one passes the bare list.
        # Confirm which form the attention model's provide_groundtruth expects.
        model_object.provide_groundtruth(test_groundtruth_text_list)
        test_input_image = tf.random_uniform(shape=[2, 32, 100, 3],
                                             minval=0,
                                             maxval=255,
                                             dtype=tf.float32,
                                             seed=1)
        prediction_dict = model_object.predict(
            model_object.preprocess(test_input_image))
        loss = model_object.loss(prediction_dict)

        with self.test_session() as sess:
            sess.run(
                [tf.global_variables_initializer(),
                 tf.tables_initializer()])
            outputs = sess.run({'loss': loss})
            print(outputs['loss'])