def _build_signature(graph: tf.Graph, inputs: dict,
                     info: dict) -> util.SignatureDef:
    """Assemble a SignatureDef from given inputs and the outputs in `info`.

    Args:
        graph: TF Graph instance handed to `util._build_tensor_info` for
            every output.
        inputs: dict that maps input names to tensor information; used
            as-is for the signature's inputs.
        info: dict holding the key SIGNATURE_OUTPUTS (flat list of output
            node names) and, optionally, the key SIGNATURE_METHOD.

    Returns:
        SignatureDef proto combining `inputs`, one output entry per listed
        output name, and a method name. The method name is taken from
        info[SIGNATURE_METHOD] when that key is present and truthy,
        otherwise tf.saved_model.PREDICT_METHOD_NAME.
    """
    outputs = {}
    for output_name in info[SIGNATURE_OUTPUTS]:
        # shape=() and dtype=0 are placeholders; presumably
        # _build_tensor_info resolves the real values from `graph`
        # (TODO confirm). The tensor is the node's first output, '<name>:0'.
        node = util.NodeInfo(name=output_name, shape=(), dtype=0,
                             tensor=output_name + ':0')
        outputs[output_name] = util._build_tensor_info(graph, node)

    method_name = tf.saved_model.PREDICT_METHOD_NAME
    if SIGNATURE_METHOD in info and info[SIGNATURE_METHOD]:
        method_name = info[SIGNATURE_METHOD]

    return util.SignatureDef(inputs=inputs, outputs=outputs,
                             method_name=method_name)
def apply(self, signature: util.SignatureDef) -> util.SignatureDef:
    """Apply the rename mapping to a SignatureDef, modifying it in place.

    Each key of ``self.mapping`` that names an input of *signature* renames
    that input. Otherwise, if it names an output, it renames that output.
    Keys matching neither are ignored. All renames are resolved against the
    original names, so their order does not matter. Swaps such as
    ``{'A': 'B', 'B': 'A'}`` and chains such as ``{'A': 'B', 'B': 'C'}``
    keep every value.

    Args:
        signature: SignatureDef proto whose inputs/outputs get renamed.

    Returns:
        The same *signature* object, updated in place.

    Raises:
        ValueError: if *signature* is not a SignatureDef proto.
    """
    if not isinstance(signature, util.SignatureDef):
        raise ValueError('signature must be a SignatureDef proto')

    def _copy_renamed(source, target, renames):
        """Copy map field *source* into *target*, renaming keys per *renames*."""
        # Copy unrenamed entries first so that a renamed entry overwrites an
        # existing entry it collides with (same precedence as before).
        for key in source:
            if key not in renames:
                target[key].CopyFrom(source[key])
        for old_name, new_name in renames.items():
            target[new_name].CopyFrom(source[old_name])

    def _replace_contents(target, replacement):
        """Replace every entry of map field *target* with *replacement*'s."""
        for key in list(target):
            del target[key]
        for key in replacement:
            target[key].CopyFrom(replacement[key])

    if not self.mapping:
        return signature

    inputs, outputs = signature.inputs, signature.outputs
    # A name present among both inputs and outputs is renamed only as an
    # input, matching the original if/elif precedence.
    input_renames = {
        old_name: new_name
        for old_name, new_name in self.mapping.items()
        if old_name in inputs
    }
    output_renames = {
        old_name: new_name
        for old_name, new_name in self.mapping.items()
        if old_name not in inputs and old_name in outputs
    }
    # Build the renamed maps in a scratch proto. Every value is read from
    # the untouched originals, which is what makes swaps/chains safe.
    temp = util.SignatureDef()
    _copy_renamed(inputs, temp.inputs, input_renames)
    _copy_renamed(outputs, temp.outputs, output_renames)
    _replace_contents(inputs, temp.inputs)
    _replace_contents(outputs, temp.outputs)
    return signature
def _extract_signature_def(
        model_json: Dict[str, Any]) -> Optional[util.SignatureDef]:
    """Read the signature definition stored in the model's meta data.

    Only a complete signature, meaning one that lists both inputs and outputs,
    is converted. Missing meta data, a missing signature, or a signature
    lacking inputs or outputs all yield None.

    Args:
        model_json: JSON dict from TFJS model file

    Returns:
        TF SignatureDef proto, with a trailing ':0' stripped from every
        input/output key and an empty method name replaced by
        tf.saved_model.PREDICT_METHOD_NAME; None if meta data is missing
        or incomplete
    """
    if USER_DEFINED_METADATA_KEY not in model_json:
        return None
    meta_data = model_json[USER_DEFINED_METADATA_KEY]
    if SIGNATURE_KEY not in meta_data:
        return None
    signature = meta_data[SIGNATURE_KEY]

    # both halves of the signature are required (checked in this order)
    required_keys = (tf.saved_model.PREDICT_INPUTS,
                     tf.saved_model.PREDICT_OUTPUTS)
    if any(key not in signature for key in required_keys):
        return None

    signature_def = ParseDict(signature, util.SignatureDef())
    if not signature_def.method_name:
        signature_def.method_name = tf.saved_model.PREDICT_METHOD_NAME

    for mapfield in (signature_def.inputs, signature_def.outputs):
        # snapshot the matching keys first so the map is not modified
        # while it is being iterated
        suffixed_keys = [key for key in mapfield if key.endswith(':0')]
        for key in suffixed_keys:
            mapfield[key[:-2]].CopyFrom(mapfield[key])
            del mapfield[key]

    return signature_def