def _build_multi_predictors_recognition_model(config, is_training):
  """Builds a MultiPredictorsRecognitionModel from its proto config.

  Args:
    config: a model_pb2.MultiPredictorsRecognitionModel proto.
    is_training: bool; forwarded to the spatial transformer, feature extractor
      and predictor builders, and to the model constructor.

  Returns:
    A multi_predictors_recognition_model.MultiPredictorsRecognitionModel.

  Raises:
    ValueError: if `config` is not a MultiPredictorsRecognitionModel proto, or
      if two entries in `config.predictor` share the same name.
  """
  if not isinstance(config, model_pb2.MultiPredictorsRecognitionModel):
    raise ValueError(
        'config not of type model_pb2.MultiPredictorsRecognitionModel')

  # The spatial transformer is optional; the model receives None when absent.
  spatial_transformer_object = None
  if config.HasField('spatial_transformer'):
    spatial_transformer_object = spatial_transformer_builder.build(
        config.spatial_transformer, is_training)

  feature_extractor_object = feature_extractor_builder.build(
      config.feature_extractor, is_training=is_training)

  # Predictors are keyed by their configured name. Reject duplicates instead
  # of letting a later predictor silently replace an earlier one.
  predictors_dict = {}
  for predictor_config in config.predictor:
    if predictor_config.name in predictors_dict:
      raise ValueError(
          'Duplicate predictor name: {}'.format(predictor_config.name))
    predictors_dict[predictor_config.name] = predictor_builder.build(
        predictor_config, is_training=is_training)

  # The regression loss is only built when keypoint supervision is enabled.
  regression_loss_object = (
      loss_builder.build(config.regression_loss)
      if config.keypoint_supervision else None)

  model_object = (
      multi_predictors_recognition_model.MultiPredictorsRecognitionModel(
          spatial_transformer=spatial_transformer_object,
          feature_extractor=feature_extractor_object,
          predictors_dict=predictors_dict,
          keypoint_supervision=config.keypoint_supervision,
          regression_loss=regression_loss_object,
          is_training=is_training,
      ))
  return model_object
def test_sync_predictor_builder(self):
  """Smoke-tests an attention predictor configured with `sync: true`.

  Builds the predictor from a text-format config, feeds it two groundtruth
  strings and one random feature map, and evaluates the loss once. Nothing is
  asserted: the test passes if graph construction and the session run succeed.
  The fetched loss is printed.
  """
  # The adjacent literals concatenate to the exact text-format config string.
  # The extra indentation before each literal only mirrors the proto nesting.
  proto_text = (
      ' attention_predictor {'
        ' rnn_cell {'
          ' lstm_cell {'
            ' num_units: 256'
            ' forget_bias: 1.0'
            ' initializer {'
              ' orthogonal_initializer {'
              ' }'
            ' }'
          ' }'
        ' }'
        ' rnn_regularizer {'
          ' l2_regularizer {'
            ' weight: 1e-4'
          ' }'
        ' }'
        ' num_attention_units: 128'
        ' max_num_steps: 10'
        ' multi_attention: false'
        ' beam_width: 1'
        ' reverse: false'
        ' label_map {'
          ' character_set {'
            ' text_string: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"'
            ' delimiter: ""'
          ' }'
          ' label_offset: 2'
        ' }'
        ' loss {'
          ' sequence_cross_entropy_loss {'
            ' sequence_normalize: false'
            ' sample_normalize: true'
          ' }'
        ' }'
        ' sync: true'
      ' } ')
  proto = predictor_pb2.Predictor()
  text_format.Merge(proto_text, proto)
  predictor = predictor_builder.build(proto, True)

  # Graph ops are created in the same order as before: the random feature map
  # first, then the groundtruth constant.
  inputs = [tf.random_uniform([2, 1, 10, 32], dtype=tf.float32)]
  groundtruth_text = tf.constant([b'hello', b'world'], dtype=tf.string)
  predictor.provide_groundtruth(groundtruth_text)
  predictions = predictor.predict(inputs)
  loss_tensor = predictor.loss(predictions)

  with self.test_session() as sess:
    init_ops = [tf.global_variables_initializer(), tf.tables_initializer()]
    sess.run(init_ops)
    fetched = sess.run({'loss': loss_tensor})
    print(fetched)