Example #1
0
def _deserialize_tensor_value(
        value_proto: executor_pb2.Value) -> _DeserializeReturnType:
    """Converts a tensor-carrying `executor_pb2.Value` into a Numpy array.

    The `tensor` field holds a packed message that is unpacked into a
    `TensorProto`. That proto is converted with `tf.make_ndarray` and paired
    with a `computation_types.TensorType` built from its dtype and shape.

    Args:
      value_proto: An instance of `executor_pb2.Value` whose `value` oneof is
        expected to be set to `tensor`.

    Returns:
      A tuple `(value, type_spec)`. `value` is the Numpy array decoded from the
      tensor proto, and `type_spec` is the `tff.TensorType` describing it.

    Raises:
      ValueError: If the `value` oneof is not `tensor`, or if the packed
        payload cannot be unpacked into a `TensorProto`.
    """
    kind = value_proto.WhichOneof('value')
    if kind != 'tensor':
        raise ValueError('Not a tensor value: {}'.format(kind))

    # TODO(b/134543154): Find some way of creating the `TensorProto` using a
    # proper public interface rather than creating a dummy value that we will
    # overwrite right away.
    unpacked = tf.make_tensor_proto(values=0)
    # `Unpack` reports failure (e.g. a payload of another message type) by
    # returning a falsy value rather than raising.
    unpacked_ok = value_proto.tensor.Unpack(unpacked)
    if not unpacked_ok:
        raise ValueError('Unable to unpack the received tensor value.')

    array = tf.make_ndarray(unpacked)
    dtype = tf.dtypes.as_dtype(unpacked.dtype)
    shape = tf.TensorShape(unpacked.tensor_shape)
    tensor_type = computation_types.TensorType(dtype=dtype, shape=shape)
    return array, tensor_type
Example #2
0
def deserialize_value(
        value_proto: executor_pb2.Value) -> _DeserializeReturnType:
    """Deserializes an `executor_pb2.Value` of any supported kind.

    The populated member of the `value` oneof selects which
    `_deserialize_*` helper handles the message.

    Args:
      value_proto: An instance of `executor_pb2.Value`.

    Returns:
      A tuple `(value, type_spec)` as produced by the selected helper.
      `value` is the deserialized payload (e.g., a Numpy array or a
      `pb.Computation` instance). `type_spec` is the type reported by that
      helper; for tensor values it is a `tff.TensorType`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the `value` oneof holds an unsupported kind (or is
        unset), or if the selected helper rejects the payload as malformed.
    """
    py_typecheck.check_type(value_proto, executor_pb2.Value)
    kind = value_proto.WhichOneof('value')
    # Each handler is wrapped in a lambda so that only the helper for the
    # populated kind is looked up and called. The `sequence` helper takes the
    # inner sub-message rather than the wrapping `Value`.
    handlers = {
        'tensor': lambda: _deserialize_tensor_value(value_proto),
        'computation': lambda: _deserialize_computation(value_proto),
        'sequence': lambda: _deserialize_sequence_value(value_proto.sequence),
        'struct': lambda: _deserialize_struct_value(value_proto),
        'federated': lambda: _deserialize_federated_value(value_proto),
    }
    handler = handlers.get(kind)
    if handler is None:
        raise ValueError(
            'Unable to deserialize a value of type {}.'.format(kind))
    return handler()


def deserialize_value(
    value_proto: executor_pb2.Value,
    type_hint: Optional[computation_types.Type] = None
) -> _DeserializeReturnType:
    """Deserializes an `executor_pb2.Value` of any supported kind.

    The populated member of the `value` oneof selects which
    `_deserialize_*` helper handles the message. `type_hint` is forwarded
    only to the sequence, struct, and federated helpers.

    Args:
      value_proto: A protocol buffer message with a `value` oneof field,
        normally an instance of `executor_pb2.Value`.
      type_hint: An optional `computation_types.Type` that hints at what the
        value type should be, for executors that only return values.

    Returns:
      A tuple `(value, type_spec)` as produced by the selected helper.
      `value` is the deserialized payload (e.g., a Numpy array or a
      `pb.Computation` instance). `type_spec` is the type reported by that
      helper; for tensor values it is a `tff.TensorType`.

    Raises:
      TypeError: If `value_proto` has no `WhichOneof` method, i.e. it is not
        a protocol buffer message.
      ValueError: If the `value` oneof holds an unsupported kind (or is
        unset), or if the selected helper rejects the payload as malformed.
    """
    # Duck-typed guard: any protobuf message exposes `WhichOneof`.
    if not hasattr(value_proto, 'WhichOneof'):
        raise TypeError(
            '`value_proto` must be a protocol buffer message with a '
            '`value` oneof field.')
    kind = value_proto.WhichOneof('value')
    # Lambdas defer the helper lookup/call until the populated kind is known.
    handlers = {
        'tensor': lambda: _deserialize_tensor_value(value_proto),
        'computation': lambda: _deserialize_computation(value_proto),
        'sequence': lambda: _deserialize_sequence_value(
            value_proto.sequence, type_hint),
        'struct': lambda: _deserialize_struct_value(value_proto, type_hint),
        'federated': lambda: _deserialize_federated_value(
            value_proto, type_hint),
    }
    handler = handlers.get(kind)
    if handler is None:
        raise ValueError(
            'Unable to deserialize a value of type {}.'.format(kind))
    return handler()