Example #1
0
def _build_signature(graph: tf.Graph, inputs: dict,
                     info: dict) -> util.SignatureDef:
    """
    Assemble a SignatureDef from the given inputs and the outputs named in info.

    Args:
        graph: TF Graph instance used to resolve the output tensors
        inputs: dict that maps input names to tensor information
        info: dict containing the key 'outputs' that maps to a flat list of
            output names and an optional key 'method_name'

    Returns:
        SignatureDef whose inputs are `inputs`, whose outputs map each output
        name to its tensor information (name, type, and shape), and whose
        method name is taken from info (if present and non-empty) or falls
        back to the predict method name
    """
    outputs = {}
    for output_name in info[SIGNATURE_OUTPUTS]:
        # each output is looked up as the first tensor (':0') of the node
        # with the same name; shape/dtype are placeholders here
        node_info = util.NodeInfo(name=output_name, shape=(), dtype=0,
                                  tensor=output_name + ':0')
        outputs[output_name] = util._build_tensor_info(graph, node_info)

    default_method = tf.saved_model.PREDICT_METHOD_NAME
    # a missing key and an empty/None value both select the default
    method_name = info.get(SIGNATURE_METHOD) or default_method
    return util.SignatureDef(inputs=inputs, outputs=outputs,
                             method_name=method_name)
Example #2
0
    def apply(self, signature: util.SignatureDef) -> util.SignatureDef:
        """
        Apply the rename mapping to a SignatureDef.

        Each key of ``self.mapping`` that names an input of the signature is
        renamed among the inputs; otherwise, if it names an output, it is
        renamed among the outputs. All renames are resolved against the
        original names, so the order of the mapping entries doesn't matter,
        e.g. ``{'A': 'B', 'B': 'A'}`` swaps two names and
        ``{'A': 'B', 'B': 'C'}`` shifts them. Entries whose names are not
        renamed are kept. If a new name collides with an unchanged name, the
        renamed entry wins; if several names are renamed to the same new
        name, the one that comes last in the mapping wins.

        The signature is modified in place and also returned.

        Args:
            signature: SignatureDef proto whose input/output keys are renamed

        Returns:
            the same (now updated) SignatureDef instance

        Raises:
            ValueError: if signature is not a SignatureDef
        """
        if not isinstance(signature, util.SignatureDef):
            raise ValueError('signature must be a SignatureDef proto')

        # test emptiness, not key truthiness (a key like '' is a valid name)
        if not self.mapping:
            return signature

        inputs, outputs = signature.inputs, signature.outputs
        # a name present in both inputs and outputs is only renamed among
        # the inputs (inputs take precedence)
        input_renames = {old: new for old, new in self.mapping.items()
                         if old in inputs}
        output_renames = {old: new for old, new in self.mapping.items()
                          if old not in inputs and old in outputs}

        def _rename(mapfield, renames):
            # Rename keys of a proto map field in place. The result is built
            # in a scratch map so every read sees the original entries (which
            # makes swaps and chains work).
            scratch_def = util.SignatureDef()  # keep the owner message alive
            scratch = scratch_def.inputs  # same value type as inputs/outputs
            # unchanged entries first, so renamed entries win on collision
            for key in mapfield:
                if key not in renames:
                    scratch[key].CopyFrom(mapfield[key])
            for old_name, new_name in renames.items():
                scratch[new_name].CopyFrom(mapfield[old_name])
            # replace the contents of the map field with the scratch map
            for key in list(mapfield):
                del mapfield[key]
            for key in scratch:
                mapfield[key].CopyFrom(scratch[key])

        if input_renames:
            _rename(inputs, input_renames)
        if output_renames:
            _rename(outputs, output_renames)
        return signature
Example #3
0
def _extract_signature_def(
        model_json: Dict[str, Any]) -> Optional[util.SignatureDef]:
    """
    Read the signature definition stored in the model's meta data.

    The meta data can be in one of three states:
      1. it holds a valid signature with both inputs and outputs
      2. it holds an incomplete signature (inputs or outputs missing)
      3. it is absent or holds no signature at all
    Only the first state produces a result.

    Args:
        model_json: JSON dict from TFJS model file

    Returns:
        TF SignatureDef proto whose input/output keys have a trailing ':0'
        removed and whose method name defaults to the predict method name;
        None if meta data is missing or incomplete
    """
    if USER_DEFINED_METADATA_KEY not in model_json:
        return None
    meta_data = model_json[USER_DEFINED_METADATA_KEY]
    if SIGNATURE_KEY not in meta_data:
        return None
    signature = meta_data[SIGNATURE_KEY]
    required_keys = (tf.saved_model.PREDICT_INPUTS,
                     tf.saved_model.PREDICT_OUTPUTS)
    if not all(key in signature for key in required_keys):
        return None

    signature_def = ParseDict(signature, util.SignatureDef())
    if not signature_def.method_name:
        signature_def.method_name = tf.saved_model.PREDICT_METHOD_NAME

    # strip the ':0' output-index suffix from tensor names used as keys
    for mapfield in (signature_def.inputs, signature_def.outputs):
        # snapshot the matching keys first: the map is modified in the loop
        suffixed_keys = [key for key in mapfield if key.endswith(':0')]
        for key in suffixed_keys:
            mapfield[key[:-2]].CopyFrom(mapfield[key])
            del mapfield[key]
    return signature_def