Beispiel #1
0
def _from_proto_sparse_tensor(sparse_tensor_proto, process_leafs):
    """Deserializes a `tf.SparseTensor` from `sparse_tensor_proto`.

    The sparse tensor is expected to be encoded as a named tuple whose name is
    `_SPARSE_TENSOR_NAME` and whose map holds the `indices`, `values` and
    `dense_shape` entries.

    Args:
      sparse_tensor_proto: A proto representing a `tf.SparseTensor`.
      process_leafs: A function to be applied to the leaf values of the nested
        structure; here it is applied to the stored value of each of the three
        fields listed above.

    Returns:
      An instance of `tf.SparseTensor`.

    Raises:
      base_errors.ModuleInfoError: If the proto does not hold a named tuple, if
        the named tuple has an unexpected name, or if any of the `indices`,
        `values` or `dense_shape` entries is missing.
    """
    if not sparse_tensor_proto.HasField("named_tuple"):
        raise base_errors.ModuleInfoError(
            "Error while deserializing a SparseTensor: expected proto tuple.")
    if sparse_tensor_proto.named_tuple.name != _SPARSE_TENSOR_NAME:
        raise base_errors.ModuleInfoError(
            "Error while deserializing a SparseTensor: The name of the tuple "
            "should have been {} but was {}.".format(
                _SPARSE_TENSOR_NAME, sparse_tensor_proto.named_tuple.name))
    named_tuple_map = sparse_tensor_proto.named_tuple.map
    # Check membership explicitly: indexing a message-valued protobuf map with
    # a missing key silently inserts an empty entry (mutating the input proto)
    # and would hand an empty string to `process_leafs`.
    missing = [key for key in ("indices", "values", "dense_shape")
               if key not in named_tuple_map]
    if missing:
        raise base_errors.ModuleInfoError(
            "Error while deserializing a SparseTensor: missing field(s) "
            "{}.".format(", ".join(missing)))
    return tf.SparseTensor(
        indices=process_leafs(named_tuple_map["indices"].value),
        values=process_leafs(named_tuple_map["values"].value),
        dense_shape=process_leafs(named_tuple_map["dense_shape"].value),
    )
Beispiel #2
0
def _nested_from_proto(nested_proto, process_leafs):
    """Rebuilds the Python value encoded in a `module_pb2.NestedData` proto.

    Decoding is recursive:
      - containers (list, tuple, dict, named tuple) are rebuilt from their
        decoded children;
      - non-empty leaves are passed through `process_leafs`;
      - registered special types delegate to their own `from_proto` function.

    Args:
      nested_proto: An instance of `module_pb2.NestedData`.
      process_leafs: A function to be applied to each non-empty leaf value of
        the nested structure.

    Returns:
      One of:
        - the processed leaf;
        - a `list`, `tuple`, `dict` or `namedtuple` of decoded children;
        - the result of a special type's `from_proto`;
        - an `_UnserializableObject` placeholder, for an empty leaf or an
          unregistered special type.

    Raises:
      base_errors.ModuleInfoError: If `nested_proto` is not a
        `module_pb2.NestedData`, or if none of the recognised fields is set.
    """
    if not isinstance(nested_proto, module_pb2.NestedData):
        raise base_errors.ModuleInfoError("Expected module_pb2.NestedData.")

    def decode(child):
        return _nested_from_proto(child, process_leafs)

    if nested_proto.HasField("value"):
        leaf = nested_proto.value
        # An empty leaf string stands for a value that could not be serialized.
        return process_leafs(leaf) if leaf else _UnserializableObject()

    if nested_proto.HasField("list"):
        return [decode(child) for child in nested_proto.list.list]

    if nested_proto.HasField("tuple"):
        return tuple([decode(child) for child in nested_proto.tuple.list])

    if nested_proto.HasField("dict"):
        entries = six.iteritems(nested_proto.dict.map)
        return {key: decode(child) for key, child in entries}

    if nested_proto.HasField("named_tuple"):
        encoded = nested_proto.named_tuple
        fields = {
            key: decode(child) for key, child in six.iteritems(encoded.map)
        }
        # Callers expect a genuine namedtuple rather than a plain dict.
        tuple_type = collections.namedtuple(encoded.name, fields.keys())
        return tuple_type(**fields)

    if nested_proto.HasField("special_type"):
        special = nested_proto.special_type
        if special.name not in _TO_PROTO_SPECIAL_TYPES:
            # Unregistered special type: fall back to a placeholder.
            return _UnserializableObject()
        type_info = _TO_PROTO_SPECIAL_TYPES[special.name]
        return type_info.from_proto(special.object, process_leafs)

    raise base_errors.ModuleInfoError(
        "Cannot deserialize a `ModuleInfo` protobuf with no fields.")
Beispiel #3
0
def _nested_to_proto(nested_value, nested_proto, process_leafs,
                     already_processed):
    """Serializes `nested_value` into `nested_proto`.

    Args:
      nested_value: A nested Python value.
      nested_proto: A `module_pb2.NestedData` instance to be filled from the
        value in `nested_value`.
      process_leafs: A function to be applied to the leaf values of the nested
        structure; its result is stored in the leaf's `value` field.
      already_processed: Set of `id()`s used to avoid infinite recursion.
        - A container's id is in the set only while its children are being
          serialized, so the set holds the containers on the current recursion
          path, plus whatever ids special-type serializers add.
        - Only a genuine reference cycle is therefore cut (written as an empty
          leaf).
        - A container that is merely shared, e.g. `(x, x)` or (in CPython)
          repeated empty tuples, is serialized at every occurrence.
        - The set is restored to its entry contents when this call returns.

    Raises:
      ModuleInfoError: If `nested_proto` is not an instance of
        `module_pb2.NestedData`.
    """
    if not isinstance(nested_proto, module_pb2.NestedData):
        raise base_errors.ModuleInfoError("Expected module_pb2.NestedData.")

    # A value that is already on the recursion path closes a cycle: mark it as
    # "unserializable" (empty leaf) to avoid infinite recursion.
    if id(nested_value) in already_processed:
        nested_proto.value = ""
        return

    # Check special types.
    for type_name, type_info in six.iteritems(_TO_PROTO_SPECIAL_TYPES):
        if type_info.check(nested_value):
            nested_proto.special_type.name = type_name
            type_info.to_proto(nested_value, nested_proto.special_type.object,
                               process_leafs, already_processed)
            return

    # Anything that is not a supported container is stored as a leaf.
    if not _is_iterable(nested_value):
        nested_proto.value = process_leafs(nested_value)
        return

    # Keep this container's id in the set only while its children are being
    # serialized. Leaving it there permanently would also flag shared,
    # non-cyclic sub-structures as unserializable. It could also make a stale
    # id collide with a newer object, since ids can be reused once an object
    # is freed.
    container_id = id(nested_value)
    already_processed.add(container_id)
    try:
        if isinstance(nested_value, dict):
            nested_proto.dict.SetInParent()
            for key, child in six.iteritems(nested_value):
                # Proto map keys are strings, so keys are stringified here.
                str_key = str(key)
                child_proto = nested_proto.dict.map[str_key]
                _nested_to_proto(child, child_proto, process_leafs,
                                 already_processed)
        elif isinstance(nested_value, tuple):
            # NamedTuple?
            if _is_namedtuple(nested_value):
                nested_proto.named_tuple.name = type(nested_value).__name__
                for str_key in nested_value._fields:
                    child = getattr(nested_value, str_key)
                    child_proto = nested_proto.named_tuple.map[str_key]
                    _nested_to_proto(child, child_proto, process_leafs,
                                     already_processed)
            else:
                nested_proto.tuple.SetInParent()
                for child in nested_value:
                    child_proto = nested_proto.tuple.list.add()
                    _nested_to_proto(child, child_proto, process_leafs,
                                     already_processed)
        else:
            # Any other supported iterable is stored as a list.
            nested_proto.list.SetInParent()
            for child in nested_value:
                child_proto = nested_proto.list.list.add()
                _nested_to_proto(child, child_proto, process_leafs,
                                 already_processed)
    finally:
        already_processed.discard(container_id)