def _create_signature(model, class_names: List[str], top_k: Optional[int]):
    """Build a classification SignatureDef that serves `model`.

    The signature's single input is a string placeholder named
    "tf_example" holding serialized tf.Example protos. Each proto must
    carry a scalar string feature "x". Every "x" value is run through
    `_preprocess_image` (presumably decoding JPEG bytes, judging by the
    feature's use; confirm against `_preprocess_image`) to produce float32
    tensors. Those tensors are fed to `model`. The highest-scoring `top_k`
    predictions are exposed as class names plus their scores.

    NOTE(review): this relies on a placeholder, so it presumably has to be
    called while building a graph (TF1-style / non-eager). Confirm with
    callers.

    Args:
      model: Callable applied to the batch of preprocessed images. Assumed
        to return per-class scores whose last axis lines up with
        `class_names` (TODO confirm).
      class_names: Vocabulary used to map predicted indices to labels.
        Indices outside it are reported as "UNK".
      top_k: How many predictions to return. A falsy value (None or 0)
        selects all classes. Values larger than the vocabulary are clamped
        to `len(class_names)`.

    Returns:
      A SignatureDef using the classify method name, with the
      CLASSIFY_INPUTS input and the CLASSIFY_OUTPUT_CLASSES /
      CLASSIFY_OUTPUT_SCORES outputs.

    Raises:
      ValueError: If `is_valid_signature` rejects the built signature.
    """
    examples_in = array_ops.placeholder(tf.string, name="tf_example")
    parsed = tf.io.parse_example(
        examples_in, {"x": tf.io.FixedLenFeature([], tf.string)})
    images = tf.map_fn(_preprocess_image, parsed["x"], dtype=tf.float32)
    scores = model(images)

    num_classes = len(class_names)
    # `or` maps None/0 to "every class"; min() caps oversized requests.
    k = min(top_k or num_classes, num_classes)
    top_scores, top_indices = tf.nn.top_k(scores, k)

    id_to_name = lookup_ops.index_to_string_table_from_tensor(
        vocabulary_list=tf.constant(class_names), default_value="UNK")

    signature_inputs = {
        sig_consts_v1.CLASSIFY_INPUTS: build_tensor_info(examples_in),
    }
    # tf.nn.top_k yields int32 indices; the table lookup is fed int64 here.
    top_names = id_to_name.lookup(tf.cast(top_indices, dtype=dtypes.int64))
    signature_outputs = {
        sig_consts_v1.CLASSIFY_OUTPUT_CLASSES: build_tensor_info(top_names),
        sig_consts_v1.CLASSIFY_OUTPUT_SCORES: build_tensor_info(top_scores),
    }

    signature = build_signature_def(
        inputs=signature_inputs,
        outputs=signature_outputs,
        method_name=sig_consts_v1.CLASSIFY_METHOD_NAME,
    )
    # Refuse to hand back a signature that is_valid_signature rejects.
    if is_valid_signature(signature):
        return signature
    raise ValueError("Invalid classification signature!")
def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Assert that the given pieces do not form a valid SignatureDef.

    `inputs`, `outputs` and `method_name` are forwarded positionally to
    `signature_def_utils_impl.build_signature_def`. The resulting signature
    must then be rejected by `signature_def_utils_impl.is_valid_signature`.
    """
    candidate = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    verdict = signature_def_utils_impl.is_valid_signature(candidate)
    self.assertFalse(verdict)
def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Check that `is_valid_signature` rejects the described signature.

    The three arguments are handed, in order, to
    `signature_def_utils_impl.build_signature_def`. The test fails if the
    built signature is reported as valid.
    """
    # NOTE(review): an identical `_assertInvalidSignature` is defined
    # immediately above. If both sit in the same class body, this later
    # definition silently replaces the earlier one, so one can be removed.
    utils = signature_def_utils_impl
    signature = utils.build_signature_def(inputs, outputs, method_name)
    self.assertFalse(utils.is_valid_signature(signature))