예제 #1
0
def _build_multi_predictors_recognition_model(config, is_training):
    """Build a MultiPredictorsRecognitionModel from its proto config.

    Args:
        config: a `model_pb2.MultiPredictorsRecognitionModel` message.
        is_training: passed through to every sub-builder call.

    Returns:
        A `multi_predictors_recognition_model.MultiPredictorsRecognitionModel`
        wired up with the spatial transformer (or None), feature extractor,
        a dict of predictors keyed by name, and an optional regression loss.

    Raises:
        ValueError: if `config` is not a MultiPredictorsRecognitionModel
            message, or if two entries in `config.predictor` share a name.
    """
    if not isinstance(config, model_pb2.MultiPredictorsRecognitionModel):
        raise ValueError(
            'config not of type model_pb2.MultiPredictorsRecognitionModel')

    # The spatial transformer is optional; the model receives None if unset.
    spatial_transformer_object = None
    if config.HasField('spatial_transformer'):
        spatial_transformer_object = spatial_transformer_builder.build(
            config.spatial_transformer, is_training)

    feature_extractor_object = feature_extractor_builder.build(
        config.feature_extractor, is_training=is_training)

    # Predictors are keyed by their configured name. Keying a dict by name
    # would silently keep only the last predictor for a repeated name, so
    # duplicates are rejected explicitly instead.
    predictors_dict = {}
    for predictor_config in config.predictor:
        predictor_name = predictor_config.name
        if predictor_name in predictors_dict:
            raise ValueError(
                'Duplicate predictor name: {}'.format(predictor_name))
        predictors_dict[predictor_name] = predictor_builder.build(
            predictor_config, is_training=is_training)

    # The regression loss is only built when keypoint supervision is enabled.
    regression_loss_object = (None if not config.keypoint_supervision else
                              loss_builder.build(config.regression_loss))

    model_object = multi_predictors_recognition_model.MultiPredictorsRecognitionModel(
        spatial_transformer=spatial_transformer_object,
        feature_extractor=feature_extractor_object,
        predictors_dict=predictors_dict,
        keypoint_supervision=config.keypoint_supervision,
        regression_loss=regression_loss_object,
        is_training=is_training,
    )
    return model_object
예제 #2
0
    def test_sync_predictor_builder(self):
        """Build an attention predictor with `sync: true` and evaluate its loss.

        Parses an `attention_predictor` text proto, supplies two groundtruth
        strings and one random feature map, builds the predictions and the
        loss, runs the loss in a session, and checks that it is finite.
        """
        # Local import: the file's module-level imports are not visible here.
        import numpy as np

        predictor_text_proto = """
    attention_predictor {
      rnn_cell {
        lstm_cell {
          num_units: 256
          forget_bias: 1.0
          initializer { orthogonal_initializer { } }
        }
      }
      rnn_regularizer { l2_regularizer { weight: 1e-4 } }
      num_attention_units: 128
      max_num_steps: 10
      multi_attention: false
      beam_width: 1
      reverse: false
      label_map {
        character_set {
          text_string: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
          delimiter: ""
        }
        label_offset: 2
      }
      loss {
        sequence_cross_entropy_loss {
          sequence_normalize: false
          sample_normalize: true
        }
      }
      sync: true
    }
    """
        predictor_proto = predictor_pb2.Predictor()
        text_format.Merge(predictor_text_proto, predictor_proto)
        # NOTE(review): second positional arg is presumably is_training.
        predictor_object = predictor_builder.build(predictor_proto, True)

        # One feature map whose leading dim (2) matches the two groundtruth
        # strings below; presumably laid out as [batch, height, width, depth].
        feature_maps = [tf.random_uniform([2, 1, 10, 32], dtype=tf.float32)]
        predictor_object.provide_groundtruth(
            tf.constant([b'hello', b'world'], dtype=tf.string))
        predictions_dict = predictor_object.predict(feature_maps)
        loss = predictor_object.loss(predictions_dict)

        with self.test_session() as sess:
            # Table initializers are run alongside variables; presumably the
            # label map is backed by lookup tables -- confirm.
            sess.run(
                [tf.global_variables_initializer(),
                 tf.tables_initializer()])
            sess_outputs = sess.run({'loss': loss})
            print(sess_outputs)
            # Fail on NaN/Inf instead of only printing the value.
            self.assertTrue(np.all(np.isfinite(sess_outputs['loss'])))