Example 1
def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
  """Stamps a parameter of a given type in the given tf.Graph instance.

  Tensors are stamped as placeholders, sequences are stamped as data sets
  constructed from string tensor handles, and named tuples are stamped by
  independently stamping their elements.

  Args:
    parameter_name: The suggested (string) name of the parameter to use in
      determining the names of the graph components to construct. The names that
      will actually appear in the graph are not guaranteed to be based on this
      suggested name, and may vary, e.g., due to existing naming conflicts, but
      a best-effort attempt will be made to make them similar for ease of
      debugging.
    parameter_type: The type of the parameter to stamp. Must be either an
      instance of computation_types.Type (or convertible to it), or None.
    graph: The instance of tf.Graph to stamp in.

  Returns:
    A tuple (val, binding), where 'val' is a Python object (such as a dataset,
    a placeholder, or a dictionary that represents a named tuple) that
    represents the stamped parameter for use in the body of a Python function
    that consumes this parameter, and the 'binding' is an instance of
    TensorFlow.Binding that indicates how parts of the type signature relate
    to the tensors and ops stamped into the graph.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the parameter type cannot be stamped in a TensorFlow graph.
  """
  py_typecheck.check_type(parameter_name, six.string_types)
  py_typecheck.check_type(graph, tf.Graph)
  if parameter_type is None:
    return (None, None)
  parameter_type = computation_types.to_type(parameter_type)
  if isinstance(parameter_type, computation_types.TensorType):
    with graph.as_default():
      placeholder = tf.placeholder(
          dtype=parameter_type.dtype,
          shape=parameter_type.shape,
          name=parameter_name)
      binding = pb.TensorFlow.Binding(
          tensor=pb.TensorFlow.TensorBinding(tensor_name=placeholder.name))
      return (placeholder, binding)
  elif isinstance(parameter_type, computation_types.NamedTupleType):
    element_name_value_pairs = []
    element_bindings = []
    for e in anonymous_tuple.to_elements(parameter_type):
      e_val, e_binding = stamp_parameter_in_graph(
          '{}_{}'.format(parameter_name, e[0]), e[1], graph)
      element_name_value_pairs.append((e[0], e_val))
      element_bindings.append(e_binding)
    return (anonymous_tuple.AnonymousTuple(element_name_value_pairs),
            pb.TensorFlow.Binding(
                tuple=pb.TensorFlow.NamedTupleBinding(
                    element=element_bindings)))
  elif isinstance(parameter_type, computation_types.SequenceType):
    with graph.as_default():
      handle = tf.placeholder(tf.string, shape=[])
    ds = make_dataset_from_string_handle(handle, parameter_type.element)
    return (ds,
            pb.TensorFlow.Binding(
                sequence=pb.TensorFlow.SequenceBinding(
                    iterator_string_handle_name=handle.name)))
  else:
    raise ValueError(
        'Parameter type component {} cannot be stamped into a TensorFlow '
        'graph.'.format(repr(parameter_type)))
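A minimal usage sketch for the function above (TF 1.x graph mode; the helper itself comes from the example, everything else is standard TensorFlow):

import tensorflow as tf

graph = tf.Graph()
val, binding = stamp_parameter_in_graph('x', tf.int32, graph)
# 'val' is an int32 placeholder stamped into 'graph' (named after 'x' on a
# best-effort basis); 'binding' records the placeholder's tensor name.
assert binding.tensor.tensor_name == val.name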
Example 2
def test_tf_type_and_shape_with_unknown_dimension(self):
    s = (tf.int32, [None])
    t = computation_types.to_type(s)
    self.assertIsInstance(t, computation_types.TensorType)
    self.assertEqual(str(t), 'int32[?]')
Example 3
def test_tuple_of_tf_types(self):
    s = (tf.int32, tf.bool)
    t = computation_types.to_type(s)
    self.assertIsInstance(t, computation_types.StructWithPythonType)
    self.assertIs(t.python_container, tuple)
    self.assertEqual(str(t), '<int32,bool>')
Example 4
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
) -> ValueImpl:
    """Converts the argument into an instance of `tff.Value`.

  The types of non-`tff.Value` arguments that are currently convertible to
  `tff.Value` include the following:

  * Lists, tuples, anonymous tuples, named tuples, and dictionaries, all
    of which are converted into instances of `tff.Tuple`.
  * Placement literals, converted into instances of `tff.Placement`.
  * Computations.
  * Python constants of type `str`, `int`, `float`, `bool`.
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
    of numpy scalar types).

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: An optional `computation_types.Type` or value convertible to it
      by `computation_types.to_type` which specifies the desired type signature
      of the resulting value. This allows for disambiguating the target type
      (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries to
      determine the type of the TFF value automatically.
    context_stack: The context stack to use.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF type
    matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. An explicit error message is raised if TensorFlow
      constructs are encountered, as TensorFlow code should be sealed away
      from the TFF federated context.
  """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
        type_utils.check_well_formed(type_spec)
    if isinstance(arg, ValueImpl):
        result = arg
    elif isinstance(arg, building_blocks.ComputationBuildingBlock):
        result = ValueImpl(arg, context_stack)
    elif isinstance(arg, placement_literals.PlacementLiteral):
        result = ValueImpl(building_blocks.Placement(arg), context_stack)
    elif isinstance(arg, computation_base.Computation):
        result = ValueImpl(
            building_blocks.CompiledComputation(
                computation_impl.ComputationImpl.get_proto(arg)),
            context_stack)
    elif type_spec is not None and isinstance(type_spec,
                                              computation_types.SequenceType):
        result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
    elif isinstance(arg, anonymous_tuple.AnonymousTuple):
        result = ValueImpl(
            building_blocks.Tuple([
                (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
                for k, v in anonymous_tuple.iter_elements(arg)
            ]), context_stack)
    elif py_typecheck.is_named_tuple(arg):
        result = to_value(arg._asdict(), None, context_stack)  # pytype: disable=attribute-error
    elif py_typecheck.is_attrs(arg):
        result = to_value(
            attr.asdict(arg,
                        dict_factory=collections.OrderedDict,
                        recurse=False), None, context_stack)
    elif isinstance(arg, dict):
        if isinstance(arg, collections.OrderedDict):
            items = arg.items()
        else:
            items = sorted(arg.items())
        value = building_blocks.Tuple([
            (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
            for k, v in items
        ])
        result = ValueImpl(value, context_stack)
    elif isinstance(arg, (tuple, list)):
        result = ValueImpl(
            building_blocks.Tuple([
                ValueImpl.get_comp(to_value(x, None, context_stack))
                for x in arg
            ]), context_stack)
    elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
        result = _wrap_constant_as_value(arg, context_stack)
    elif isinstance(arg, (tf.Tensor, tf.Variable)):
        raise TypeError(
            'TensorFlow construct {} has been encountered in a federated '
            'context. TFF does not support mixing TF and federated orchestration '
            'code. Please wrap any TensorFlow constructs with '
            '`tff.tf_computation`.'.format(arg))
    elif isinstance(arg, function_utils.PolymorphicFunction):
        # TODO(b/129567727) remove this case when this is no longer an error
        raise TypeError(
            'Polymorphic computations cannot be converted to a TFF value. Consider '
            'explicitly specifying the argument types of a computation before '
            'passing it to a function that requires a TFF value (such as a TFF '
            'intrinsic like federated_map).')
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a TFF value.'.
            format(py_typecheck.type_string(type(arg))))
    py_typecheck.check_type(result, ValueImpl)
    if (type_spec is not None and not type_utils.is_assignable_from(
            type_spec, result.type_signature)):
        raise TypeError(
            'The supplied argument maps to TFF type {}, which is incompatible with '
            'the requested type {}.'.format(result.type_signature, type_spec))
    return result
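A hedged usage sketch for `to_value`; the module path and the global `context_stack` below are assumptions about TFF's internal layout, shown for illustration only:

# Hypothetical: obtain the global context stack used by TFF internals.
from tensorflow_federated.python.core.impl import context_stack_impl

v = to_value(10, None, context_stack_impl.context_stack)
# Per the constant branch above, 'v' is a ValueImpl of TFF type int32.
str(v.type_signature)  # -> 'int32'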
Example 5
def test_tf_tensorspec(self):
    s = tf.TensorSpec([None, 3], dtype=tf.float32)
    t = computation_types.to_type(s)
    self.assertIsInstance(t, computation_types.TensorType)
    self.assertEqual(str(t), 'float32[?,3]')
Example 6
def pack_args_into_anonymous_tuple(args, kwargs, type_spec=None, context=None):
    """Packs positional and keyword arguments into an anonymous tuple.

  If 'type_spec' is not None, it must be a tuple type or something that's
  convertible to it by computation_types.to_type(). The assignment of arguments
  to fields of the tuple follows the same rule as during function calls. If
  'type_spec' is None, the positional arguments precede any of the keyword
  arguments, and the ordering of the keyword arguments matches the ordering in
  which they appear in kwargs. If the latter is an OrderedDict, the ordering
  will be preserved. On the other hand, if the latter is an ordinary unordered
  dict, the ordering is arbitrary.

  Args:
    args: Positional arguments.
    kwargs: Keyword arguments.
    type_spec: The optional type specification (either an instance of
      computation_types.NamedTupleType or something convertible to it), or None
      if there's no type. Used to drive the arrangements of args into fields of
      the constructed anonymous tuple, as noted in the description.
    context: The optional context (an instance of `context_base.Context`) in
      which the arguments are being packed. Required if and only if the
      `type_spec` is not `None`.

  Returns:
    An anonymous tuple containing all the arguments.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
    type_spec = computation_types.to_type(type_spec)
    if not type_spec:
        return anonymous_tuple.AnonymousTuple([(None, arg) for arg in args] +
                                              list(six.iteritems(kwargs)))
    else:
        py_typecheck.check_type(type_spec, computation_types.NamedTupleType)
        py_typecheck.check_type(context, context_base.Context)
        if not is_argument_tuple(type_spec):
            raise TypeError(
                'Parameter type {} does not have a structure of an argument '
                'tuple, and cannot be populated from multiple positional and '
                'keyword arguments'.format(str(type_spec)))
        else:
            result_elements = []
            positions_used = set()
            keywords_used = set()
            for index, (name, elem_type) in enumerate(
                    anonymous_tuple.to_elements(type_spec)):
                if index < len(args):
                    if name is not None and name in kwargs:
                        raise TypeError(
                            'Argument {} specified twice.'.format(name))
                    else:
                        arg_value = args[index]
                        result_elements.append(
                            (name, context.ingest(arg_value, elem_type)))
                        positions_used.add(index)
                elif name is not None and name in kwargs:
                    arg_value = kwargs[name]
                    result_elements.append(
                        (name, context.ingest(arg_value, elem_type)))
                    keywords_used.add(name)
                elif name:
                    raise TypeError(
                        'Argument named {} is missing.'.format(name))
                else:
                    raise TypeError(
                        'Argument at position {} is missing.'.format(index))
            positions_missing = set(range(
                len(args))).difference(positions_used)
            if positions_missing:
                raise TypeError('Positional arguments at {} not used.'.format(
                    positions_missing))
            keywords_missing = set(kwargs.keys()).difference(keywords_used)
            if keywords_missing:
                raise TypeError('Keyword arguments at {} not used.'.format(
                    keywords_missing))
            return anonymous_tuple.AnonymousTuple(result_elements)
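A small sketch of the `type_spec=None` path above, which requires no context:

packed = pack_args_into_anonymous_tuple((1, 2), {'a': 3})
# Positional args become unnamed elements, keyword args named elements:
print(packed)  # -> <1,2,a=3>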
Example 7
    def __init__(self,
                 initialize,
                 prepare,
                 work,
                 zero,
                 accumulate,
                 merge,
                 report,
                 bitwidth,
                 update,
                 server_state_label=None,
                 client_data_label=None):
        """Constructs a representation of a MapReduce-like iterative process.

    Note: All the computations supplied here as arguments must be TensorFlow
    computations, i.e., instances of `tff.Computation` constructed by the
    `tff.tf_computation` decorator/wrapper.

    Args:
      initialize: The computation that produces the initial server state.
      prepare: The computation that prepares the input for the clients.
      work: The client-side work computation.
      zero: The computation that produces the initial state for accumulators.
      accumulate: The computation that adds a client update to an accumulator.
      merge: The computation to use for merging pairs of accumulators.
      report: The computation that produces the final server-side aggregate for
        the top level accumulator (the global update).
      bitwidth: The computation that produces the bitwidth for secure sum.
      update: The computation that takes the global update and the server state
        and produces the new server state, as well as server-side output.
      server_state_label: Optional string label for the server state.
      client_data_label: Optional string label for the client data.

    Raises:
      TypeError: If the Python or TFF types of the arguments are invalid or not
        compatible with each other.
      AssertionError: If the manner in which the given TensorFlow computations
        are represented by TFF does not match what this code is expecting (this
        is an internal error that requires code update).
    """
        for label, comp in (
            ('initialize', initialize),
            ('prepare', prepare),
            ('work', work),
            ('zero', zero),
            ('accumulate', accumulate),
            ('merge', merge),
            ('report', report),
            ('bitwidth', bitwidth),
            ('update', update),
        ):
            _check_tensorflow_computation(label, comp)

        prepare_arg_type = prepare.type_signature.parameter
        init_result_type = initialize.type_signature.result
        if not _is_assignable_from_or_both_none(prepare_arg_type,
                                                init_result_type):
            raise TypeError(
                'The `prepare` computation expects an argument of type {}, '
                'which does not match the result type {} of `initialize`.'.
                format(prepare_arg_type, init_result_type))

        _check_accepts_two_tuple('work', work)
        work_2nd_arg_type = work.type_signature.parameter[1]
        prepare_result_type = prepare.type_signature.result
        if not _is_assignable_from_or_both_none(work_2nd_arg_type,
                                                prepare_result_type):
            raise TypeError(
                'The `work` computation expects an argument tuple with type {} as '
                'the second element (the initial client state from the server), '
                'which does not match the result type {} of `prepare`.'.format(
                    work_2nd_arg_type, prepare_result_type))

        _check_returns_two_tuple('work', work)

        py_typecheck.check_len(accumulate.type_signature.parameter, 2)
        accumulate.type_signature.parameter[0].check_assignable_from(
            zero.type_signature.result)
        accumulate_2nd_arg_type = accumulate.type_signature.parameter[1]
        work_client_update_type = work.type_signature.result[0]
        if not _is_assignable_from_or_both_none(accumulate_2nd_arg_type,
                                                work_client_update_type):
            raise TypeError(
                'The `accumulate` computation expects a second argument of type {}, '
                'which does not match the expected {} as implied by the type '
                'signature of `work`.'.format(accumulate_2nd_arg_type,
                                              work_client_update_type))
        accumulate.type_signature.parameter[0].check_assignable_from(
            accumulate.type_signature.result)

        py_typecheck.check_len(merge.type_signature.parameter, 2)
        merge.type_signature.parameter[0].check_assignable_from(
            accumulate.type_signature.result)
        merge.type_signature.parameter[1].check_assignable_from(
            accumulate.type_signature.result)
        merge.type_signature.parameter[0].check_assignable_from(
            merge.type_signature.result)

        report.type_signature.parameter.check_assignable_from(
            merge.type_signature.result)

        expected_update_parameter_type = computation_types.to_type([
            initialize.type_signature.result,
            [report.type_signature.result, work.type_signature.result[1]],
        ])
        if not _is_assignable_from_or_both_none(
                update.type_signature.parameter,
                expected_update_parameter_type):
            raise TypeError(
                'The `update` computation expects an argument of type {}, '
                'which does not match the expected {} as implied by the type '
                'signatures of `initialize`, `report`, and `work`.'.format(
                    update.type_signature.parameter,
                    expected_update_parameter_type))

        _check_returns_two_tuple('update', update)

        updated_state_type = update.type_signature.result[0]
        if not prepare_arg_type.is_assignable_from(updated_state_type):
            raise TypeError(
                'The `update` computation returns a result tuple whose first element '
                f'(the updated state type of the server) is type:\n'
                f'{updated_state_type}\n'
                f'which is not assignable to the state parameter type of `prepare`:\n'
                f'{prepare_arg_type}')

        self._initialize = initialize
        self._prepare = prepare
        self._work = work
        self._zero = zero
        self._accumulate = accumulate
        self._merge = merge
        self._report = report
        self._bitwidth = bitwidth
        self._update = update

        if server_state_label is not None:
            py_typecheck.check_type(server_state_label, str)
        self._server_state_label = server_state_label
        if client_data_label is not None:
            py_typecheck.check_type(client_data_label, str)
        self._client_data_label = client_data_label
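The helper `_is_assignable_from_or_both_none` is not shown above; a plausible definition consistent with how it is called (an assumption, not the verbatim library code) would be:

def _is_assignable_from_or_both_none(first, second):
    """Sketch: True when both types are None, or 'first' accepts 'second'."""
    if first is None:
        return second is None
    return second is not None and first.is_assignable_from(second)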
Example 8
def append_to_list_structure_for_element_type_spec(nested, value, type_spec):
    """Adds an element `value` to `nested` lists for `type_spec`.

  This function appends tensor-level constituents of an element `value` to the
  lists created by `make_empty_list_structure_for_element_type_spec`. The
  nested structure of `value` must match that created by the above function,
  and be consistent with `type_spec`.

  Args:
    nested: Output of `make_empty_list_structure_for_element_type_spec`.
    value: A value (Python object) that is a hierarchical structure of
      dictionaries, lists, and other containers holding tensor-like items,
      matching the hierarchy of `type_spec`.
    type_spec: An instance of `tff.Type` or something convertible to it, as in
      `make_empty_list_structure_for_element_type_spec`.

  Raises:
    TypeError: If the `type_spec` is not of a form described above, or the value
      is not of a type compatible with `type_spec`.
  """
    if value is None:
        return
    type_spec = computation_types.to_type(type_spec)
    # TODO(b/113116813): This could be made more efficient, but for now we won't
    # need to worry about it as this is an odd corner case.
    if isinstance(value, structure.Struct):
        elements = structure.to_elements(value)
        if all(k is not None for k, _ in elements):
            value = collections.OrderedDict(elements)
        elif all(k is None for k, _ in elements):
            value = tuple([v for _, v in elements])
        else:
            raise TypeError(
                'Expected an anonymous tuple to either have all elements named or '
                'all unnamed, got {}.'.format(value))
    if type_spec.is_tensor():
        py_typecheck.check_type(nested, list)
        # Convert the members to tensors to ensure that they are properly
        # typed and grouped before being passed to
        # tf.data.Dataset.from_tensor_slices.
        nested.append(tf.convert_to_tensor(value, type_spec.dtype))
    elif type_spec.is_struct():
        elements = structure.to_elements(type_spec)
        if isinstance(nested, collections.OrderedDict):
            if py_typecheck.is_named_tuple(value):
                # In Python 3.8 and later `_asdict` no longer returns an
                # `OrderedDict`, but rather a regular `dict`.
                value = collections.OrderedDict(value._asdict())
            if isinstance(value, dict):
                if set(value.keys()) != set(k for k, _ in elements):
                    raise TypeError('Value {} does not match type {}.'.format(
                        value, type_spec))
                for elem_name, elem_type in elements:
                    append_to_list_structure_for_element_type_spec(
                        nested[elem_name], value[elem_name], elem_type)
            elif isinstance(value, (list, tuple)):
                if len(value) != len(elements):
                    raise TypeError('Value {} does not match type {}.'.format(
                        value, type_spec))
                for idx, (elem_name, elem_type) in enumerate(elements):
                    append_to_list_structure_for_element_type_spec(
                        nested[elem_name], value[idx], elem_type)
            else:
                raise TypeError(
                    'Unexpected type of value {} for TFF type {}.'.format(
                        py_typecheck.type_string(type(value)), type_spec))
        elif isinstance(nested, tuple):
            py_typecheck.check_type(value, (list, tuple))
            if len(value) != len(elements):
                raise TypeError('Value {} does not match type {}.'.format(
                    value, type_spec))
            for idx, (_, elem_type) in enumerate(elements):
                append_to_list_structure_for_element_type_spec(
                    nested[idx], value[idx], elem_type)
        else:
            raise TypeError(
                'Invalid nested structure, unexpected container type {}.'.
                format(py_typecheck.type_string(type(nested))))
    else:
        raise TypeError(
            'Expected a tensor or named tuple type, found {}.'.format(
                type_spec))
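A hypothetical usage sketch for the struct branch above; the empty `nested` structure is written out by hand, mirroring what `make_empty_list_structure_for_element_type_spec` would presumably build for this type:

import collections
import tensorflow as tf

elem_type = computation_types.StructType([('x', tf.int32), ('y', tf.int32)])
nested = collections.OrderedDict(x=[], y=[])  # hand-built empty structure
append_to_list_structure_for_element_type_spec(nested, {'x': 1, 'y': 2}, elem_type)
append_to_list_structure_for_element_type_spec(nested, {'x': 3, 'y': 4}, elem_type)
# Each field of 'nested' now holds a list of two int32 tensors, ready for
# tf.data.Dataset.from_tensor_slices.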
Example 9
def make_data_set_from_elements(graph, elements, element_type):
    """Creates a `tf.data.Dataset` in `graph` from explicitly listed `elements`.

  Note: The underlying implementation attempts to use the
  `tf.data.Dataset.from_tensor_slices()` method to build the data set quickly,
  but this doesn't always work. The typical scenario where it breaks is a data
  set composed of batches of unequal sizes. Typically, only the last
  batch is odd, so on the first attempt, we try to construct two data sets,
  one from all elements but the last one, and one from the last element, then
  concatenate the two. In the unlikely case that this fails (e.g., because
  all data set elements are batches of unequal sizes), we revert to the slow,
  but reliable method of constructing data sets from singleton elements, and
  then concatenating them all.

  Args:
    graph: The graph in which to construct the `tf.data.Dataset`, or `None` if
      the construction is to happen in the eager context.
    elements: A list of elements.
    element_type: The type of elements.

  Returns:
    The constructed `tf.data.Dataset` instance.

  Raises:
    TypeError: If element types do not match `element_type`.
    ValueError: If the elements are of incompatible types and shapes, or if
      no graph was specified outside of the eager context.
  """
    # Note: We allow the graph to be `None` to allow this function to be used in
    # the eager context.
    if graph is not None:
        py_typecheck.check_type(graph, tf.Graph)
    elif not tf.executing_eagerly():
        raise ValueError('Only in eager context may the graph be `None`.')
    py_typecheck.check_type(elements, list)
    element_type = computation_types.to_type(element_type)
    py_typecheck.check_type(element_type, computation_types.Type)

    def _make(element_subset):
        lists = make_empty_list_structure_for_element_type_spec(element_type)
        for el in element_subset:
            append_to_list_structure_for_element_type_spec(
                lists, el, element_type)
        tensor_slices = replace_empty_leaf_lists_with_numpy_arrays(
            lists, element_type)
        return tf.data.Dataset.from_tensor_slices(tensor_slices)

    def _work():  # pylint: disable=missing-docstring
        if not elements:
            # Just return an empty data set with the appropriate types.
            dummy_element = make_dummy_element_for_type_spec(element_type)
            ds = _make([dummy_element]).take(0)
        elif len(elements) == 1:
            ds = _make(elements)
        else:
            try:
                # It is common for the last element to be a batch of a size different
                # from all the preceding batches. With this in mind, we proactively
                # single out the last element (optimizing for the common case).
                ds = _make(elements[0:-1]).concatenate(_make(elements[-1:]))
            except ValueError:
                # In case elements beyond just the last one are of unequal shapes, we
                # may have failed (the most likely cause), so fall back onto the slow
                # process of constructing and joining data sets from singletons. Not
                # optimizing this for now, as it's very unlikely in scenarios
                # we're targeting.
                #
                # Note: this will not remain `None` because `elements` is not empty.
                ds = None
                ds = typing.cast(tf.data.Dataset, ds)
                for i in range(len(elements)):
                    singleton_ds = _make(elements[i:i + 1])
                    ds = singleton_ds if ds is None else ds.concatenate(
                        singleton_ds)
        ds_element_type = computation_types.to_type(ds.element_spec)
        if not element_type.is_assignable_from(ds_element_type):
            raise TypeError(
                'Failure during data set construction, expected elements of type {}, '
                'but the constructed data set has elements of type {}.'.format(
                    element_type, ds_element_type))
        return ds

    if graph is not None:
        with graph.as_default():
            return _work()
    else:
        return _work()
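A hypothetical sketch in the eager context (`graph=None`), using batches of unequal size, which is exactly the case the note above singles out:

import collections
import numpy as np

element_type = computation_types.StructType(
    [('x', computation_types.TensorType(tf.float32, [None]))])
elements = [collections.OrderedDict(x=np.array([1., 2.], np.float32)),
            collections.OrderedDict(x=np.array([3.], np.float32))]
ds = make_data_set_from_elements(None, elements, element_type)
# Two dataset elements, with 'x' of shapes [2] and [1] respectively.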
Example 10
def capture_result_from_graph(result, graph):
    """Captures a result stamped into a tf.Graph as a type signature and binding.

  Args:
    result: The result to capture, a Python object that is composed of tensors,
      possibly nested within Python structures such as dictionaries, lists,
      tuples, or named tuples.
    graph: The instance of tf.Graph to use.

  Returns:
    A tuple (type_spec, binding), where 'type_spec' is an instance of
    computation_types.Type that describes the type of the result, and 'binding'
    is an instance of TensorFlow.Binding that indicates how parts of the result
    type relate to the tensors and ops that appear in the result.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
  """
    def _get_bindings_for_elements(name_value_pairs, graph, type_fn):
        """Build `(type_spec, binding)` tuple for name value pairs."""
        element_name_type_binding_triples = [
            ((k, ) + capture_result_from_graph(v, graph))
            for k, v in name_value_pairs
        ]
        type_spec = type_fn([((e[0], e[1]) if e[0] else e[1])
                             for e in element_name_type_binding_triples])
        binding = pb.TensorFlow.Binding(struct=pb.TensorFlow.StructBinding(
            element=[e[2] for e in element_name_type_binding_triples]))
        return type_spec, binding

    # TODO(b/113112885): The emerging extensions for serializing SavedModels may
    # end up introducing similar concepts of bindings, etc., we should look here
    # into the possibility of reusing some of that code when it's available.
    if isinstance(result, TENSOR_REPRESENTATION_TYPES):
        with graph.as_default():
            result = tf.constant(result)
    if tf.is_tensor(result):
        if hasattr(result, 'read_value'):
            # We have a tf.Variable-like result, get a proper tensor to fetch.
            with graph.as_default():
                result = result.read_value()
        return (computation_types.TensorType(result.dtype.base_dtype,
                                             result.shape),
                pb.TensorFlow.Binding(tensor=pb.TensorFlow.TensorBinding(
                    tensor_name=result.name)))
    elif py_typecheck.is_named_tuple(result):
        # Special handling needed for collections.namedtuples since they do not have
        # anything in the way of a shared base class. Note we don't want to rely on
        # the fact that collections.namedtuples inherit from 'tuple' because we'd be
        # failing to retain the information about naming of tuple members.
        # pylint: disable=protected-access
        name_value_pairs = result._asdict().items()
        # pylint: enable=protected-access
        return _get_bindings_for_elements(
            name_value_pairs, graph,
            functools.partial(computation_types.StructWithPythonType,
                              container_type=type(result)))
    elif py_typecheck.is_attrs(result):
        name_value_pairs = attr.asdict(result,
                                       dict_factory=collections.OrderedDict,
                                       recurse=False)
        return _get_bindings_for_elements(
            name_value_pairs.items(), graph,
            functools.partial(computation_types.StructWithPythonType,
                              container_type=type(result)))
    elif isinstance(result, structure.Struct):
        return _get_bindings_for_elements(structure.to_elements(result), graph,
                                          computation_types.StructType)
    elif isinstance(result, collections.abc.Mapping):
        if isinstance(result, collections.OrderedDict):
            name_value_pairs = result.items()
        else:
            name_value_pairs = sorted(result.items())
        return _get_bindings_for_elements(
            name_value_pairs, graph,
            functools.partial(computation_types.StructWithPythonType,
                              container_type=type(result)))
    elif isinstance(result, (list, tuple)):
        element_type_binding_pairs = [
            capture_result_from_graph(e, graph) for e in result
        ]
        return (computation_types.StructWithPythonType(
            [e[0] for e in element_type_binding_pairs], type(result)),
                pb.TensorFlow.Binding(struct=pb.TensorFlow.StructBinding(
                    element=[e[1] for e in element_type_binding_pairs])))
    elif isinstance(result, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
        variant_tensor = tf.data.experimental.to_variant(result)
        element_structure = result.element_spec
        try:
            element_type = computation_types.to_type(element_structure)
        except TypeError as e:
            raise TypeError(
                'TFF does not support Datasets that yield elements of structure {!s}'
                .format(element_structure)) from e
        return (computation_types.SequenceType(element_type),
                pb.TensorFlow.Binding(sequence=pb.TensorFlow.SequenceBinding(
                    variant_tensor_name=variant_tensor.name)))
    else:
        raise TypeError(
            'Cannot capture a result of an unsupported type {}.'.format(
                py_typecheck.type_string(type(result))))
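A hypothetical sketch of the mapping branch above, capturing an `OrderedDict` of constants stamped into a graph:

import collections
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    result = collections.OrderedDict(a=tf.constant(1), b=tf.constant(2.0))
type_spec, binding = capture_result_from_graph(result, graph)
print(type_spec)  # -> <a=int32,b=float32>, a StructWithPythonType of OrderedDict
# 'binding' is a struct binding holding one tensor binding per element.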
Example 11
def assemble_result_from_graph(type_spec, binding, output_map):
    """Assembles a result stamped into a `tf.Graph` given type signature/binding.

  This method does roughly the opposite of `capture_result_from_graph`, in that
  whereas `capture_result_from_graph` starts with a single structured object
  made up of tensors and computes its type and bindings, this method starts
  with the type/bindings and constructs a structured object made up of tensors.

  Args:
    type_spec: The type signature of the result to assemble, an instance of
      `types.Type` or something convertible to it.
    binding: The binding that relates the type signature to names of tensors in
      the graph, an instance of `pb.TensorFlow.Binding`.
    output_map: The mapping from tensor names that appear in the binding to
      actual stamped tensors (possibly renamed during import).

  Returns:
    The assembled result, a Python object that is composed of tensors, possibly
    nested within Python structures such as anonymous tuples.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
    ValueError: If the arguments are invalid or inconsistent with each other,
      e.g., the type and binding don't match, or the tensor is not found in
      the map.
  """
    type_spec = computation_types.to_type(type_spec)
    py_typecheck.check_type(type_spec, computation_types.Type)
    py_typecheck.check_type(binding, pb.TensorFlow.Binding)
    py_typecheck.check_type(output_map, dict)
    for k, v in output_map.items():
        py_typecheck.check_type(k, str)
        if not tf.is_tensor(v):
            raise TypeError(
                'Element with key {} in the output map is {}, not a tensor.'.
                format(k, py_typecheck.type_string(type(v))))

    binding_oneof = binding.WhichOneof('binding')
    if type_spec.is_tensor():
        if binding_oneof != 'tensor':
            raise ValueError(
                'Expected a tensor binding, found {}.'.format(binding_oneof))
        elif binding.tensor.tensor_name not in output_map:
            raise ValueError(
                'Tensor named {} not found in the output map.'.format(
                    binding.tensor.tensor_name))
        else:
            return output_map[binding.tensor.tensor_name]
    elif type_spec.is_struct():
        if binding_oneof != 'struct':
            raise ValueError(
                'Expected a struct binding, found {}.'.format(binding_oneof))
        else:
            type_elements = structure.to_elements(type_spec)
            if len(binding.struct.element) != len(type_elements):
                raise ValueError(
                    'Mismatching tuple sizes in type ({}) and binding ({}).'.
                    format(len(type_elements), len(binding.struct.element)))
            result_elements = []
            for (element_name,
                 element_type), element_binding in zip(type_elements,
                                                       binding.struct.element):
                element_object = assemble_result_from_graph(
                    element_type, element_binding, output_map)
                result_elements.append((element_name, element_object))
            if type_spec.python_container is None:
                return structure.Struct(result_elements)
            container_type = type_spec.python_container
            if (py_typecheck.is_named_tuple(container_type)
                    or py_typecheck.is_attrs(container_type)):
                return container_type(**dict(result_elements))
            return container_type(result_elements)
    elif type_spec.is_sequence():
        if binding_oneof != 'sequence':
            raise ValueError(
                'Expected a sequence binding, found {}.'.format(binding_oneof))
        else:
            sequence_oneof = binding.sequence.WhichOneof('binding')
            if sequence_oneof == 'variant_tensor_name':
                variant_tensor = output_map[
                    binding.sequence.variant_tensor_name]
                return make_dataset_from_variant_tensor(
                    variant_tensor, type_spec.element)
            else:
                raise ValueError('Unsupported sequence binding \'{}\'.'.format(
                    sequence_oneof))
    else:
        raise ValueError('Unsupported type \'{}\'.'.format(type_spec))
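A hedged round-trip sketch, pairing this function with `capture_result_from_graph` from the previous example:

graph = tf.Graph()
with graph.as_default():
    tensor = tf.constant(5)
type_spec, binding = capture_result_from_graph(tensor, graph)
output_map = {binding.tensor.tensor_name: tensor}
assert assemble_result_from_graph(type_spec, binding, output_map) is tensor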
Example 12
def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
    """Stamps a parameter of a given type in the given tf.Graph instance.

  Tensors are stamped as placeholders, sequences are stamped as data sets
  constructed from string tensor handles, and named tuples are stamped by
  independently stamping their elements.

  Args:
    parameter_name: The suggested (string) name of the parameter to use in
      determining the names of the graph components to construct. The names that
      will actually appear in the graph are not guaranteed to be based on this
      suggested name, and may vary, e.g., due to existing naming conflicts, but
      a best-effort attempt will be made to make them similar for ease of
      debugging.
    parameter_type: The type of the parameter to stamp. Must be either an
      instance of computation_types.Type (or convertible to it), or None.
    graph: The instance of tf.Graph to stamp in.

  Returns:
    A tuple (val, binding), where 'val' is a Python object (such as a dataset,
    a placeholder, or a `structure.Struct` that represents a named
    tuple) that represents the stamped parameter for use in the body of a Python
    function that consumes this parameter, and the 'binding' is an instance of
    TensorFlow.Binding that indicates how parts of the type signature relate
    to the tensors and ops stamped into the graph.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the parameter type cannot be stamped in a TensorFlow graph.
  """
    py_typecheck.check_type(parameter_name, str)
    py_typecheck.check_type(graph, tf.Graph)
    if parameter_type is None:
        return (None, None)
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type.is_tensor():
        with graph.as_default():
            placeholder = tf.compat.v1.placeholder(dtype=parameter_type.dtype,
                                                   shape=parameter_type.shape,
                                                   name=parameter_name)
            binding = pb.TensorFlow.Binding(tensor=pb.TensorFlow.TensorBinding(
                tensor_name=placeholder.name))
            return (placeholder, binding)
    elif parameter_type.is_struct():
        # The parameter_type could be a StructWithPythonType; however, we
        # ignore that for now. Instead, the proper containers will be inserted
        # at call time by function_utils.wrap_as_zero_or_one_arg_callable.
        if not parameter_type:
            # Stamps dummy element to "populate" graph, as TensorFlow does not support
            # empty graphs.
            dummy_tensor = tf.no_op()
        element_name_value_pairs = []
        element_bindings = []
        for e in structure.iter_elements(parameter_type):
            e_val, e_binding = stamp_parameter_in_graph(
                '{}_{}'.format(parameter_name, e[0]), e[1], graph)
            element_name_value_pairs.append((e[0], e_val))
            element_bindings.append(e_binding)
        return (structure.Struct(element_name_value_pairs),
                pb.TensorFlow.Binding(struct=pb.TensorFlow.StructBinding(
                    element=element_bindings)))
    elif parameter_type.is_sequence():
        with graph.as_default():
            variant_tensor = tf.compat.v1.placeholder(tf.variant, shape=[])
            ds = make_dataset_from_variant_tensor(variant_tensor,
                                                  parameter_type.element)
        return (ds,
                pb.TensorFlow.Binding(sequence=pb.TensorFlow.SequenceBinding(
                    variant_tensor_name=variant_tensor.name)))
    else:
        raise ValueError(
            'Parameter type component {!r} cannot be stamped into a TensorFlow '
            'graph.'.format(parameter_type))
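A hypothetical sketch of the struct branch above (TF 1.x graph mode, as in the example):

graph = tf.Graph()
param_type = computation_types.StructType([('x', tf.int32), ('y', tf.float32)])
val, binding = stamp_parameter_in_graph('arg', param_type, graph)
# 'val' is a structure.Struct of two placeholders (named after 'arg_x' and
# 'arg_y' on a best-effort basis); 'binding' is a struct of tensor bindings.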
Example 13
def make_data_set_from_elements(graph, elements, element_type):
  """Creates a `tf.data.Dataset` in `graph` from explicitly listed `elements`.

  NOTE: The underlying implementation attempts to use the
  `tf.data.Dataset.from_tensor_slices()` method to build the data set quickly,
  but this doesn't always work. The typical scenario where it breaks is a data
  set composed of batches of unequal sizes. Typically, only the last
  batch is odd, so on the first attempt, we try to construct two data sets,
  one from all elements but the last one, and one from the last element, then
  concatenate the two. In the unlikely case that this fails (e.g., because
  all data set elements are batches of unequal sizes), we revert to the slow,
  but reliable method of constructing data sets from singleton elements, and
  then concatenating them all.

  Args:
    graph: The graph in which to construct the `tf.data.Dataset`.
    elements: A list of elements.
    element_type: The type of elements.

  Returns:
    The constructed `tf.data.Dataset` instance.

  Raises:
    TypeError: If element types do not match `element_type`.
    ValueError: If the elements are of incompatible types and shapes.
  """
  py_typecheck.check_type(graph, tf.Graph)
  py_typecheck.check_type(elements, list)
  element_type = computation_types.to_type(element_type)
  py_typecheck.check_type(element_type, computation_types.Type)

  def _make(element_subset):
    structure = make_empty_list_structure_for_element_type_spec(element_type)
    for el in element_subset:
      append_to_list_structure_for_element_type_spec(structure, el,
                                                     element_type)
    tensor_slices = to_tensor_slices_from_list_structure_for_element_type_spec(
        structure, element_type)
    return tf.data.Dataset.from_tensor_slices(tensor_slices)

  with graph.as_default():
    if not elements:
      # Just return an empty data set with the appropriate types.
      dummy_element = _make_dummy_element_for_type_spec(element_type)
      ds = _make([dummy_element]).take(0)
    elif len(elements) == 1:
      ds = _make(elements)
    else:
      try:
        # It is common for the last element to be a batch of a size different
        # from all the preceding batches. With this in mind, we proactively
        # single out the last element (optimizing for the common case).
        ds = _make(elements[0:-1]).concatenate(_make(elements[-1:]))
      except ValueError:
        # In case elements beyond just the last one are of unequal shapes, we
        # may have failed (the most likely cause), so fall back onto the slow
        # process of constructing and joining data sets from singletons. Not
        # optimizing this for now, as it's very unlikely in scenarios
        # we're targeting.
        ds = None
        for i in range(len(elements)):
          singleton_ds = _make(elements[i:i + 1])
          ds = singleton_ds if ds is None else ds.concatenate(singleton_ds)
    ds_element_type = type_utils.tf_dtypes_and_shapes_to_type(
        ds.output_types, ds.output_shapes)
    if not type_utils.is_assignable_from(element_type, ds_element_type):
      raise TypeError(
          'Failure during data set construction, expected elements of type {}, '
          'but the constructed data set has elements of type {}.'.format(
              str(element_type), str(ds_element_type)))
  return ds
Example 14
def append_to_list_structure_for_element_type_spec(structure, value, type_spec):
  """Adds an element `value` to a nested `structure` of lists for `type_spec`.

  This function appends tensor-level constituents of an element `value` to the
  lists created by `make_empty_list_structure_for_element_type_spec`. The
  nested structure of `value` must match that created by the above function,
  and be consistent with `type_spec`.

  Args:
    structure: Output of `make_empty_list_structure_for_element_type_spec`.
    value: A value (Python object) that is a hierarchical structure of
      dictionaries, lists, and other containers holding tensor-like items,
      matching the hierarchy of `type_spec`.
    type_spec: An instance of `tff.Type` or something convertible to it, as in
      `make_empty_list_structure_for_element_type_spec`.

  Raises:
    TypeError: If the `type_spec` is not of a form described above, or the value
      is not of a type compatible with `type_spec`.
  """
  if value is None:
    return
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  # TODO(b/113116813): This could be made more efficient, but for now we won't
  # need to worry about it as this is an odd corner case.
  if isinstance(value, anonymous_tuple.AnonymousTuple):
    value = collections.OrderedDict(anonymous_tuple.to_elements(value))
  if isinstance(type_spec, computation_types.TensorType):
    py_typecheck.check_type(structure, list)
    structure.append(value)
  elif isinstance(type_spec, computation_types.NamedTupleType):
    elements = anonymous_tuple.to_elements(type_spec)
    if isinstance(structure, collections.OrderedDict):
      if py_typecheck.is_named_tuple(value):
        value = value._asdict()
      if isinstance(value, dict):
        if set(value.keys()) != set(k for k, _ in elements):
          raise TypeError('Value {} does not match type {}.'.format(
              str(value), str(type_spec)))
        for elem_name, elem_type in elements:
          append_to_list_structure_for_element_type_spec(
              structure[elem_name], value[elem_name], elem_type)
      elif isinstance(value, (list, tuple)):
        if len(value) != len(elements):
          raise TypeError('Value {} does not match type {}.'.format(
              str(value), str(type_spec)))
        for idx, (elem_name, elem_type) in enumerate(elements):
          append_to_list_structure_for_element_type_spec(
              structure[elem_name], value[idx], elem_type)
      else:
        raise TypeError('Unexpected type of value {} for TFF type {}.'.format(
            py_typecheck.type_string(type(value)), str(type_spec)))
    elif isinstance(structure, tuple):
      py_typecheck.check_type(value, (list, tuple))
      if len(value) != len(elements):
        raise TypeError('Value {} does not match type {}.'.format(
            str(value), str(type_spec)))
      for idx, (_, elem_type) in enumerate(elements):
        append_to_list_structure_for_element_type_spec(structure[idx],
                                                       value[idx], elem_type)
    else:
      raise TypeError(
          'Invalid nested structure, unexpected container type {}.'.format(
              py_typecheck.type_string(type(structure))))
  else:
    raise TypeError('Expected a tensor or named tuple type, found {}.'.format(
        str(type_spec)))
Example 15
def build_jax_federated_averaging_process(batch_type, model_type, loss_fn,
                                          step_size):
    """Constructs an iterative process that implements simple federated averaging.

  Args:
    batch_type: An instance of `tff.Type` that represents the type of a single
      batch of data to use for training. This type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the sort
      that are expected as parameters to `loss_fn`.
    model_type: An instance of `tff.Type` that represents the type of the model.
      Similarly to `batch_type`, this type should be constructed with standard
      Python containers (such as `collections.OrderedDict`) of the sort that are
      expected as parameters to `loss_fn`.
    loss_fn: A loss function for the model. Must be a Python function that takes
      two parameters, one of them being the model, and the other being a single
      batch of data (with types matching `batch_type` and `model_type`).
    step_size: The step size to use during training (a Python `float`).

  Returns:
    An instance of `tff.templates.IterativeProcess` that implements federated
    training in JAX.
  """
    batch_type = computation_types.to_type(batch_type)
    model_type = computation_types.to_type(model_type)

    py_typecheck.check_type(batch_type, computation_types.Type)
    py_typecheck.check_type(model_type, computation_types.Type)
    py_typecheck.check_callable(loss_fn)
    py_typecheck.check_type(step_size, float)  # np.float (an alias of float) was removed from NumPy.

    def _tensor_zeros(tensor_type):
        return jax.numpy.zeros(tensor_type.shape.dims,
                               dtype=tensor_type.dtype.as_numpy_dtype)

    @experimental_computations.jax_computation
    def _create_zero_model():
        model_zeros = structure.map_structure(_tensor_zeros, model_type)
        return type_conversions.type_to_py_container(model_zeros, model_type)

    @computations.federated_computation
    def _create_zero_model_on_server():
        return intrinsics.federated_eval(_create_zero_model, placements.SERVER)

    def _apply_update(model_param, param_delta):
        return model_param - step_size * param_delta

    @experimental_computations.jax_computation(model_type, batch_type)
    def _train_on_one_batch(model, batch):
        params = structure.flatten(
            structure.from_container(model, recursive=True))
        grads = structure.flatten(
            structure.from_container(jax.api.grad(loss_fn)(model, batch)))
        updated_params = [_apply_update(x, y) for (x, y) in zip(params, grads)]
        trained_model = structure.pack_sequence_as(model_type, updated_params)
        return type_conversions.type_to_py_container(trained_model, model_type)

    local_dataset_type = computation_types.SequenceType(batch_type)

    @computations.federated_computation(model_type, local_dataset_type)
    def _train_on_one_client(model, batches):
        return intrinsics.sequence_reduce(batches, model, _train_on_one_batch)

    @computations.federated_computation(
        computation_types.FederatedType(model_type, placements.SERVER),
        computation_types.FederatedType(local_dataset_type,
                                        placements.CLIENTS))
    def _train_one_round(model, federated_data):
        locally_trained_models = intrinsics.federated_map(
            _train_on_one_client,
            collections.OrderedDict([('model',
                                      intrinsics.federated_broadcast(model)),
                                     ('batches', federated_data)]))
        return intrinsics.federated_mean(locally_trained_models)

    return iterative_process.IterativeProcess(
        initialize_fn=_create_zero_model_on_server, next_fn=_train_one_round)
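A hypothetical end-to-end sketch for a scalar linear model; the types, loss function, and step size below are illustrative only:

import collections
import numpy as np

batch_type = computation_types.to_type(
    collections.OrderedDict(x=np.float32, y=np.float32))
model_type = computation_types.to_type(
    collections.OrderedDict(w=np.float32, b=np.float32))

def loss_fn(model, batch):
    # Squared error of a scalar linear model, written with jax.numpy.
    pred = model['w'] * batch['x'] + model['b']
    return jax.numpy.square(pred - batch['y'])

process = build_jax_federated_averaging_process(
    batch_type, model_type, loss_fn, 0.01)
state = process.initialize()
# state = process.next(state, federated_data)  # one training round per call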
Example 16
def pack_args(parameter_type, args: Sequence[Any], kwargs: Mapping[str, Any],
              context: context_base.Context):
    """Pack arguments into a single one that matches the given parameter type.

  The arguments may or may not be packed into a tuple, depending on the type of
  the parameter, and how many arguments are present.

  Args:
    parameter_type: The type of the single parameter expected by a computation,
      an instance of computation_types.Type or something convertible to it, or
      None if the computation is not expecting a parameter.
    args: Positional arguments of a call.
    kwargs: Keyword arguments of a call.
    context: The context (an instance of `context_base.Context`) in which the
      arguments are being packed.

  Returns:
    A single value object of a type that matches 'parameter_type', containing
    all the arguments, or None if the 'parameter_type' is None.

  Raises:
    TypeError: if the args/kwargs do not match the given parameter type.
  """
    py_typecheck.check_type(context, context_base.Context)
    if parameter_type is None:
        # If there's no parameter type, there should be no args of any kind.
        if args or kwargs:
            raise TypeError('Was not expecting any arguments.')
        else:
            return None
    else:
        parameter_type = computation_types.to_type(parameter_type)
        if not args and not kwargs:
            raise TypeError(
                'Declared a parameter of type {}, but got no arguments.'.
                format(parameter_type))
        else:
            single_positional_arg = (len(args) == 1) and not kwargs
            if not isinstance(parameter_type,
                              computation_types.NamedTupleType):
                # If not a named tuple type, a single positional argument is the only
                # supported call style.
                if not single_positional_arg:
                    raise TypeError(
                        'Parameter type {} is compatible only with a single positional '
                        'argument, but found {} positional and {} keyword args.'
                        .format(parameter_type, len(args), len(kwargs)))
                else:
                    arg = args[0]
            elif single_positional_arg:
                arg = args[0]
            elif not is_argument_tuple(parameter_type):
                raise TypeError(
                    'Parameter type {} does not have a structure of an argument '
                    'tuple, and cannot be populated from multiple positional and '
                    'keyword arguments; please construct a tuple before the '
                    'call.'.format(parameter_type))
            else:
                arg = pack_args_into_anonymous_tuple(args, kwargs,
                                                     parameter_type, context)
            return context.ingest(arg, parameter_type)
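A hypothetical sketch: a pass-through `context_base.Context` subclass is enough to observe how `pack_args` dispatches on the parameter type:

class _PassThroughContext(context_base.Context):

    def ingest(self, val, type_spec):
        return val

    def invoke(self, comp, arg):
        raise NotImplementedError()

ctx = _PassThroughContext()
pack_args(None, [], {}, ctx)       # -> None (no parameter expected)
pack_args(tf.int32, [5], {}, ctx)  # -> 5 (single positional argument)
# A named tuple parameter type would instead route through
# pack_args_into_anonymous_tuple, shown in an earlier example.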
Example 17
def infer_type(arg: Any) -> Optional[computation_types.Type]:
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  WARNING: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  - tensors, variables, and data sets,
  - things that are convertible to tensors (including numpy arrays, builtin
    types, as well as lists and tuples of any of the above, etc.),
  - nested lists, tuples, namedtuples, anonymous tuples, dicts, and OrderedDicts.

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument is
    `None`.
  """
  # TODO(b/113112885): Implement the remaining cases here on an as-needed basis.
  if arg is None:
    return None
  elif isinstance(arg, typed_object.TypedObject):
    return arg.type_signature
  elif tf.is_tensor(arg):
    return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    element_type = computation_types.to_type(arg.element_spec)
    return computation_types.SequenceType(element_type)
  elif isinstance(arg, structure.Struct):
    return computation_types.StructType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in structure.iter_elements(arg)
    ])
  elif py_typecheck.is_attrs(arg):
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif py_typecheck.is_named_tuple(arg):
    # In Python 3.8 and later, `_asdict` no longer returns an `OrderedDict`,
    # but a regular `dict`.
    items = collections.OrderedDict(arg._asdict())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      items = sorted(arg.items())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items], type(arg))
  elif isinstance(arg, (tuple, list)):
    elements = []
    all_elements_named = True
    for element in arg:
      all_elements_named &= py_typecheck.is_name_value_pair(element)
      elements.append(infer_type(element))
    # If this is a tuple of (name, value) pairs, the caller most likely intended
    # this to be a StructType, so we avoid storing the Python container.
    if elements and all_elements_named:
      return computation_types.StructType(elements)
    else:
      return computation_types.StructWithPythonType(elements, type(arg))
  elif isinstance(arg, str):
    return computation_types.TensorType(tf.string)
  elif isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)
  else:
    arg_type = type(arg)
    if arg_type is bool:
      return computation_types.TensorType(tf.bool)
    elif arg_type is int:
      # Choose the integral type based on the value.
      if arg > tf.int64.max or arg < tf.int64.min:
        raise TypeError('No integral type support for values outside range '
                        f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
      elif arg > tf.int32.max or arg < tf.int32.min:
        return computation_types.TensorType(tf.int64)
      else:
        return computation_types.TensorType(tf.int32)
    elif arg_type is float:
      return computation_types.TensorType(tf.float32)
    else:
      # Now fall back on heavier-weight processing, as all else has failed.
      # Use make_tensor_proto() to make sure to handle it consistently with
      # how TensorFlow handles values (e.g., recognizing int as int32, as
      # opposed to int64 as in NumPy).
      try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
      except TypeError as err:
        raise TypeError('Could not infer the TFF type of {}: {}'.format(
            py_typecheck.type_string(type(arg)), err))
Example No. 18
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument, and returns a combo
      of (args, kwargs) to invoke `traced_fn` with (e.g., as the one constructed
      by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    else:
        packed_arg = None

    args, kwargs = arg_fn(packed_arg)

    # While the fake parameters are fed via args/kwargs during serialization,
    # it is possible for them to get reordered in the actual generated XLA
    # code. We use here the same flattening function used by the JAX
    # serializer to determine the ordering, and allow it to be captured in
    # the parameter binding. We do not need to do anything special for the
    # results, since the results, if multiple, are always returned as a tuple.
    flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
    tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))

    def _adjust_arg(x):
        if isinstance(x, structure.Struct):
            return type_conversions.type_to_py_container(x, x.type_signature)
        else:
            return x

    args = [_adjust_arg(x) for x in args]
    kwargs = {k: _adjust_arg(v) for k, v in kwargs.items()}

    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes, computation_type)
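The tensor-ordering step is worth seeing in isolation. A sketch assuming `jax` and `numpy` are available; `FakeParam` is a hypothetical stand-in for the serializer arguments, of which only `tensor_index` matters here:

import jax
import numpy as np

class FakeParam:

  def __init__(self, tensor_index):
    self.tensor_index = tensor_index

args = (FakeParam(2), FakeParam(0))
kwargs = {'z': FakeParam(1)}
# tree_flatten fixes the traversal order JAX will use for the parameters.
flattened, _ = jax.tree_util.tree_flatten((args, kwargs))
# argsort recovers the permutation from traversal order to tensor-index order.
tensor_indexes = list(np.argsort([x.tensor_index for x in flattened]))
print(tensor_indexes)  # [1, 2, 0]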
Example No. 19
def wrap_as_zero_or_one_arg_callable(fn, parameter_type=None, unpack=None):
    """Wraps around `fn` so it accepts up to one positional TFF-typed argument.

  This function helps to simplify dealing with functions and defuns that might
  have diverse and complex signatures, but that represent computations and as
  such, conceptually only accept a single parameter. The returned callable has
  a single positional parameter or no parameters. If it has one parameter, the
  parameter is expected to contain all arguments required by `fn` and matching
  the supplied parameter type signature bundled together into an anonymous
  tuple, if needed. The callable unpacks that structure, and passes all of
  its elements as positional or keyword-based arguments in the call to `fn`.

  Example usage:

    @tf.function
    def my_fn(x, y, z=10, name='bar', *p, **q):
      return x + y

    type_spec = (tf.int32, tf.int32)

    wrapped_fn = wrap_as_zero_or_one_arg_callable(my_fn, type_spec)

    arg = AnonymousTuple([('x', 10), ('y', 20)])

    ... = wrapped_fn(arg)

  Args:
    fn: The underlying backend function or defun to invoke with the unpacked
      arguments.
    parameter_type: The TFF type of the parameter bundle to be accepted by the
      returned callable, if any, or None if there's no parameter.
    unpack: Whether to break the parameter down into constituent parts and feed
      them as arguments to `fn` (True), leave the parameter as is and pass it to
      `fn` as a single unit (False), or allow it to be inferred from the
      signature of `fn` (None). In the latter case (None), if any ambiguity
      arises, an exception is thrown. If the parameter_type is None, this value
      has no effect, and is simply ignored.

  Returns:
    The zero- or one-argument callable that invokes `fn` with the unbundled
    arguments, as described above.

  Raises:
    TypeError: if arguments to this call are of the wrong types, or if the
      supplied 'parameter_type' is not compatible with `fn`.
  """
    # TODO(b/113112885): Revisit whether the 3-way 'unpack' knob is sufficient
    # for our needs, or more options are needed.
    if unpack not in [True, False, None]:
        raise TypeError(
            'The unpack argument has an unexpected value {}.'.format(
                repr(unpack)))
    argspec = get_argspec(fn)
    parameter_type = computation_types.to_type(parameter_type)
    if not parameter_type:
        if is_argspec_compatible_with_types(argspec):
            # Deliberate wrapping to isolate the caller from `fn`, e.g., to prevent
            # the caller from mistakenly specifying args that match fn's defaults.
            return lambda: fn()  # pylint: disable=unnecessary-lambda
        else:
            raise TypeError(
                'The argspec {} of the supplied function cannot be interpreted as a '
                'body of a no-parameter computation.'.format(str(argspec)))
    else:
        if infer_unpack_needed(fn, parameter_type, unpack):
            arg_types, kwarg_types = unpack_args_from_tuple(parameter_type)

            def _unpack_and_call(fn, arg_types, kwarg_types, arg):
                """An interceptor function that unpacks 'arg' before calling `fn`.

        The function verifies the actual parameters before it forwards the
        call as a last-minute check.

        Args:
          fn: The function or defun to invoke.
          arg_types: The list of positional argument types (guaranteed to all be
            instances of computation_types.Types).
          kwarg_types: The dictionary of keyword argument types (guaranteed to
            all be instances of computation_types.Types).
          arg: The argument to unpack.

        Returns:
          The result of invoking `fn` on the unpacked arguments.

        Raises:
          TypeError: if types don't match.
        """
                py_typecheck.check_type(
                    arg, (anonymous_tuple.AnonymousTuple, value_base.Value))
                args = []
                for idx, expected_type in enumerate(arg_types):
                    element_value = arg[idx]
                    actual_type = type_utils.infer_type(element_value)
                    if not type_utils.is_assignable_from(
                            expected_type, actual_type):
                        raise TypeError(
                            'Expected element at position {} to be '
                            'of type {}, found {}.'.format(
                                idx, str(expected_type), str(actual_type)))
                    if isinstance(element_value,
                                  anonymous_tuple.AnonymousTuple):
                        element_value = type_utils.convert_to_py_container(
                            element_value, expected_type)
                    args.append(element_value)
                kwargs = {}
                for name, expected_type in six.iteritems(kwarg_types):
                    element_value = getattr(arg, name)
                    actual_type = type_utils.infer_type(element_value)
                    if not type_utils.is_assignable_from(
                            expected_type, actual_type):
                        raise TypeError('Expected element named {} to be '
                                        'of type {}, found {}.'.format(
                                            name, str(expected_type),
                                            str(actual_type)))
                    if type_utils.is_anon_tuple_with_py_container(
                            element_value, expected_type):
                        element_value = type_utils.convert_to_py_container(
                            element_value, expected_type)
                    kwargs[name] = element_value
                return fn(*args, **kwargs)

            # TODO(b/132888123): Consider other options to avoid possible bugs here.
            try:
                (fn, arg_types, kwarg_types)
            except NameError:
                raise AssertionError('Args to be bound must be in scope.')
            return lambda arg: _unpack_and_call(fn, arg_types, kwarg_types, arg)
        else:
            # An interceptor function that verifies the actual parameter before it
            # forwards the call as a last-minute check.
            def _call(fn, parameter_type, arg):
                arg_type = type_utils.infer_type(arg)
                if not type_utils.is_assignable_from(parameter_type, arg_type):
                    raise TypeError(
                        'Expected an argument of type {}, found {}.'.format(
                            str(parameter_type), str(arg_type)))
                if type_utils.is_anon_tuple_with_py_container(
                        arg, parameter_type):
                    arg = type_utils.convert_to_py_container(
                        arg, parameter_type)
                return fn(arg)

            # TODO(b/132888123): Consider other options to avoid possible bugs here.
            try:
                (fn, parameter_type)
            except NameError:
                raise AssertionError('Args to be bound must be in scope.')
            return lambda arg: _call(fn, parameter_type, arg)
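A stripped-down version of the unpacking interceptor, using a plain mapping instead of an anonymous tuple and omitting all type checking (the names here are illustrative, not TFF API):

def wrap_unpacking(fn, arg_names):
  # Build a one-argument callable that unpacks a mapping into keyword
  # arguments before forwarding the call, as _unpack_and_call does above.
  def wrapped(arg):
    return fn(**{name: arg[name] for name in arg_names})
  return wrapped

def my_fn(x, y):
  return x + y

wrapped_fn = wrap_unpacking(my_fn, ['x', 'y'])
print(wrapped_fn({'x': 10, 'y': 20}))  # 30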
Example No. 20
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
    parameter_type_hint=None,
) -> ValueImpl:
    """Converts the argument into an instance of `tff.Value`.

  The types of non-`tff.Value` arguments that are currently convertible to
  `tff.Value` include the following:

  * Lists, tuples, `structure.Struct`s, named tuples, and dictionaries, all
    of which are converted into instances of `tff.Tuple`.
  * Placement literals, converted into instances of `tff.Placement`.
  * Computations.
  * Python constants of type `str`, `int`, `float`, `bool`
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
    of numpy scalar types).

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: An optional `computation_types.Type` or value convertible to it
      by `computation_types.to_type` which specifies the desired type signature
      of the resulting value. This allows for disambiguating the target type
      (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries to
      determine the type of the TFF value automatically.
    context_stack: The context stack to use.
    parameter_type_hint: An optional `computation_types.Type` or value
      convertible to it by `computation_types.to_type` which specifies an
      argument type to use in the case that `arg` is a
      `function_utils.PolymorphicFunction`.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF type
    matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. Raises explicit error message if TensorFlow constructs
      are encountered, as TensorFlow code should be sealed away from TFF
      federated context.
  """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    _check_symbol_binding_context(context_stack.current)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    if isinstance(arg, ValueImpl):
        result = arg
    elif isinstance(arg, building_blocks.ComputationBuildingBlock):
        result = ValueImpl(arg, context_stack)
    elif isinstance(arg, placement_literals.PlacementLiteral):
        result = ValueImpl(building_blocks.Placement(arg), context_stack)
    elif isinstance(
            arg,
        (computation_base.Computation, function_utils.PolymorphicFunction)):
        if isinstance(arg, function_utils.PolymorphicFunction):
            if parameter_type_hint is None:
                raise TypeError(
                    'Polymorphic computations cannot be converted to TFF values '
                    'without a type hint. Consider explicitly specifying the '
                    'argument types of a computation before passing it to a '
                    'function that requires a TFF value (such as a TFF intrinsic '
                    'like `federated_map`). If you are a TFF developer and think '
                    'this should be supported, consider providing `parameter_type_hint` '
                    'as an argument to the encompassing `to_value` conversion.'
                )
            parameter_type_hint = computation_types.to_type(
                parameter_type_hint)
            arg = arg.fn_for_argument_type(parameter_type_hint)
        py_typecheck.check_type(arg, computation_base.Computation)
        result = ValueImpl(arg.to_compiled_building_block(), context_stack)
    elif type_spec is not None and type_spec.is_sequence():
        result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
    elif isinstance(arg, structure.Struct):
        result = ValueImpl(
            building_blocks.Struct([
                (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
                for k, v in structure.iter_elements(arg)
            ]), context_stack)
    elif py_typecheck.is_named_tuple(arg):
        items = arg._asdict().items()
        result = _dictlike_items_to_value(items, context_stack, type(arg))
    elif py_typecheck.is_attrs(arg):
        items = attr.asdict(arg,
                            dict_factory=collections.OrderedDict,
                            recurse=False).items()
        result = _dictlike_items_to_value(items, context_stack, type(arg))
    elif isinstance(arg, dict):
        if isinstance(arg, collections.OrderedDict):
            items = arg.items()
        else:
            items = sorted(arg.items())
        result = _dictlike_items_to_value(items, context_stack, type(arg))
    elif isinstance(arg, (tuple, list)):
        result = ValueImpl(
            building_blocks.Struct([
                ValueImpl.get_comp(to_value(x, None, context_stack))
                for x in arg
            ], type(arg)), context_stack)
    elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
        result = _wrap_constant_as_value(arg, context_stack)
    elif isinstance(arg, (tf.Tensor, tf.Variable)):
        raise TypeError(
            'TensorFlow construct {} has been encountered in a federated '
            'context. TFF does not support mixing TF and federated orchestration '
            'code. Please wrap any TensorFlow constructs with '
            '`tff.tf_computation`.'.format(arg))
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a TFF value.'.
            format(py_typecheck.type_string(type(arg))))
    py_typecheck.check_type(result, ValueImpl)
    if (type_spec is not None
            and not type_spec.is_assignable_from(result.type_signature)):
        raise TypeError(
            'The supplied argument maps to TFF type {}, which is incompatible with '
            'the requested type {}.'.format(result.type_signature, type_spec))
    return result
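One detail of the dict branch deserves a note: plain dicts are canonicalized by sorting keys, while OrderedDicts keep insertion order. A self-contained sketch of that rule (`ordered_items` is a hypothetical helper):

import collections

def ordered_items(mapping):
  # OrderedDict: trust insertion order. Plain dict: sort by key so the
  # resulting structure is deterministic, as to_value does above.
  if isinstance(mapping, collections.OrderedDict):
    return list(mapping.items())
  return sorted(mapping.items())

print(ordered_items({'b': 1, 'a': 2}))  # [('a', 2), ('b', 1)]
print(ordered_items(collections.OrderedDict([('b', 1), ('a', 2)])))  # [('b', 1), ('a', 2)]

Sorting makes the inferred structure independent of construction order for unordered mappings, even though Python 3.7+ dicts happen to preserve insertion order.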
Example No. 21
    def test_build_tf_computations_for_sum(self, encoder_constructor):
        # Tests that the partial computations have matching relevant input-output
        # signatures.
        value_spec = tf.TensorSpec((20, ), tf.float32)
        encoder = te.encoders.as_gather_encoder(encoder_constructor(),
                                                value_spec)

        _, state_type = encoding_utils._build_initial_state_tf_computation(
            encoder)
        value_type = computation_types.to_type(value_spec)
        nest_encoder = encoding_utils._build_tf_computations_for_gather(
            state_type, value_type, encoder)

        self.assertEqual(state_type,
                         nest_encoder.get_params_fn.type_signature.parameter)
        encode_params_type = nest_encoder.get_params_fn.type_signature.result[0]
        decode_before_sum_params_type = (
            nest_encoder.get_params_fn.type_signature.result[1])
        decode_after_sum_params_type = (
            nest_encoder.get_params_fn.type_signature.result[2])

        self.assertEqual(value_type,
                         nest_encoder.encode_fn.type_signature.parameter[0])
        self.assertEqual(encode_params_type,
                         nest_encoder.encode_fn.type_signature.parameter[1])
        self.assertEqual(decode_before_sum_params_type,
                         nest_encoder.encode_fn.type_signature.parameter[2])
        state_update_tensors_type = (
            nest_encoder.encode_fn.type_signature.result[2])

        accumulator_type = nest_encoder.zero_fn.type_signature.result
        self.assertEqual(state_update_tensors_type,
                         accumulator_type.state_update_tensors)

        self.assertEqual(
            accumulator_type,
            nest_encoder.accumulate_fn.type_signature.parameter[0])
        self.assertEqual(
            nest_encoder.encode_fn.type_signature.result,
            nest_encoder.accumulate_fn.type_signature.parameter[1])
        self.assertEqual(accumulator_type,
                         nest_encoder.accumulate_fn.type_signature.result)
        self.assertEqual(accumulator_type,
                         nest_encoder.merge_fn.type_signature.parameter[0])
        self.assertEqual(accumulator_type,
                         nest_encoder.merge_fn.type_signature.parameter[1])
        self.assertEqual(accumulator_type,
                         nest_encoder.merge_fn.type_signature.result)
        self.assertEqual(accumulator_type,
                         nest_encoder.report_fn.type_signature.parameter)
        self.assertEqual(accumulator_type,
                         nest_encoder.report_fn.type_signature.result)

        self.assertEqual(
            accumulator_type.values,
            nest_encoder.decode_after_sum_fn.type_signature.parameter[0])
        self.assertEqual(
            decode_after_sum_params_type,
            nest_encoder.decode_after_sum_fn.type_signature.parameter[1])
        self.assertEqual(
            value_type, nest_encoder.decode_after_sum_fn.type_signature.result)

        self.assertEqual(
            state_type,
            nest_encoder.update_state_fn.type_signature.parameter[0])
        self.assertEqual(
            state_update_tensors_type,
            nest_encoder.update_state_fn.type_signature.parameter[1])
        self.assertEqual(state_type,
                         nest_encoder.update_state_fn.type_signature.result)
Example No. 22
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`.  The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    The constructed `pb.Computation` instance with the `pb.TensorFlow` variant
      set.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = inspect.getargspec(target)  # pylint: disable=deprecated-method

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type:
      if len(argspec.args) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, '
            'found {}.'.format(repr(argspec.args)))
      parameter_name = argspec.args[0]
      parameter_value, parameter_binding = graph_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
      args.append(parameter_value)
    else:
      if argspec.args:
        raise ValueError(
            'Expected the target to declare no parameters, found {}.'.format(
                repr(argspec.args)))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

      # TODO(b/122081673): This needs to change for TF 2.0. We may also
      # want to allow the person creating a tff.tf_computation to specify
      # a different initializer; e.g., if it is known that certain
      # variables will be assigned immediately to arguments of the function,
      # then it is wasteful to initialize them before this.
      #
      # The following is a bit of a work around: the collections below may
      # contain variables more than once, hence we throw into a set. TFF needs
      # to ensure all variables are initialized, but not all variables are
      # always in the collections we expect. tff.learning._KerasModel tries to
      # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
      # TFF_MODEL_VARIABLES for now.
      all_variables = set(
          tf.global_variables() + tf.local_variables() +
          tf.get_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(all_variables))
        with tf.control_dependencies(context.init_ops):
          # Before running the main new init op, run any initializers for sub-
          # computations from context.init_ops. Variables from import_graph_def
          # will not make it into the global collections, and so will not be
          # initialized without this code path.
          init_op_name = tf.initializers.variables(
              all_variables, name=name).name
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = graph_utils.capture_result_from_graph(
        result, graph)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name))
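The stamping and capture steps are ordinary TF1-style graph construction. A minimal sketch, assuming the `tf.compat.v1` API is available (names are illustrative):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # Stamp the parameter as a placeholder and build the result from it.
  x = tf.compat.v1.placeholder(tf.int32, shape=[], name='x')
  result = tf.add(x, 1, name='result')

# The serialized computation refers to tensors by name in its bindings.
print(x.name, result.name)  # x:0 result:0

with tf.compat.v1.Session(graph=graph) as sess:
  print(sess.run(result, feed_dict={x: 41}))  # 42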
Example No. 23
 def test_tf_type(self):
     s = tf.int32
     t = computation_types.to_type(s)
     self.assertIsInstance(t, computation_types.TensorType)
     self.assertEqual(str(t), 'int32')
Example No. 24
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(input_tensor_names), len(args)))
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      graph_def = tensorflow_utils.add_control_deps_for_init_op(
          graph_def, init_op)

    def _import_fn():
      return tf.import_graph_def(
          graph_merge.uniquify_shared_names(graph_def),
          input_map=dict(list(zip(input_tensor_names, args))),
          return_elements=output_tensor_names)

    if device is not None:
      with tf.device(device):
        return _import_fn()
    else:
      return _import_fn()

  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)
    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)

  if device is not None:
    old_fn_to_return = fn_to_return

    # pylint: disable=function-redefined
    def fn_to_return(x):
      with tf.device(device):
        return old_fn_to_return(x)

    # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
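At its core, the embedding is the `tf.compat.v1.wrap_function` plus `import_graph_def` pattern. A self-contained sketch with a hypothetical doubling graph standing in for the deserialized computation:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  x = tf.compat.v1.placeholder(tf.float32, shape=[], name='x')
  y = tf.identity(x * 2.0, name='y')
graph_def = graph.as_graph_def()

def function_to_wrap(arg):
  # Re-import the serialized graph, feeding `arg` for the placeholder and
  # returning the output tensor by name.
  return tf.compat.v1.import_graph_def(
      graph_def, input_map={'x:0': arg}, return_elements=['y:0'])

wrapped_fn = tf.compat.v1.wrap_function(
    function_to_wrap, [tf.TensorSpec([], tf.float32)])
print(wrapped_fn(tf.constant(3.0)))  # [<tf.Tensor: ... numpy=6.0>]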
Example No. 25
 def test_tf_type_and_shape(self):
     s = (tf.int32, [10])
     t = computation_types.to_type(s)
     self.assertIsInstance(t, computation_types.TensorType)
     self.assertEqual(str(t), 'int32[10]')
Example No. 26
 def test_is_signature_compatible_with_types_false(self, signature, args,
                                                   kwargs):
   self.assertFalse(
       function_utils.is_signature_compatible_with_types(
           signature, *[computation_types.to_type(a) for a in args],
           **{k: computation_types.to_type(v) for k, v in kwargs.items()}))
Example No. 27
 def test_list_of_tf_types(self):
     s = [tf.int32, tf.bool]
     t = computation_types.to_type(s)
     self.assertIsInstance(t, computation_types.StructWithPythonType)
     self.assertEqual(str(t), '<int32,bool>')
Example No. 28
def build_federated_evaluation(model_fn,
                               use_experimental_simulation_loop: bool = False):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning a
      pre-constructed model on each call will result in an error.
    use_experimental_simulation_loop: Controls the reduce loop function for the
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(model)
    batch_type = computation_types.to_type(model.input_spec)

  @computations.tf_computation(model_weights_type,
                               computation_types.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluting `model_weights` on `dataset`."""
    model = model_utils.enhance(model_fn())

    @tf.function
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""
      tf_computation_utils.assign(model.weights, incoming_model_weights)

      def reduce_fn(prev_loss, batch):
        model_output = model.forward_pass(batch, training=False)
        return prev_loss + tf.cast(model_output.loss, tf.float64)

      dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
          use_experimental_simulation_loop)
      dataset_reduce_fn(
          reduce_fn=reduce_fn,
          dataset=dataset,
          initial_state_fn=lambda: tf.constant(0, dtype=tf.float64))

      return collections.OrderedDict([('local_outputs',
                                       model.report_local_outputs())])

    return _tf_client_eval(incoming_model_weights, dataset)

  @computations.federated_computation(
      computation_types.FederatedType(model_weights_type, placements.SERVER),
      computation_types.FederatedType(
          computation_types.SequenceType(batch_type), placements.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    client_outputs = intrinsics.federated_map(client_eval, [
        intrinsics.federated_broadcast(server_model_weights), federated_dataset
    ])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
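The client-side work above is just a dataset reduction. The same pattern on a plain `tf.data` pipeline (illustrative only, without the experimental simulation loop):

import tensorflow as tf

dataset = tf.data.Dataset.range(5)
# Same shape as reduce_fn above: fold each element into a running scalar.
total = dataset.reduce(
    tf.constant(0, dtype=tf.int64), lambda prev, batch: prev + batch)
print(total.numpy())  # 10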
Example No. 29
 def test_singleton_named_tf_type(self):
     s = ('a', tf.int32)
     t = computation_types.to_type(s)
     self.assertIsInstance(t, computation_types.StructWithPythonType)
     self.assertIs(t.python_container, tuple)
     self.assertEqual(str(t), '<a=int32>')
Example No. 30
 async def create_value(self, value, type_spec=None):
     type_spec = computation_types.to_type(type_spec)
     py_typecheck.check_type(type_spec, computation_types.Type)
     if isinstance(value, intrinsic_defs.IntrinsicDef):
         if not type_utils.is_concrete_instance_of(type_spec,
                                                   value.type_signature):  # pytype: disable=attribute-error
             raise TypeError(
                 'Incompatible type {} used with intrinsic {}.'.format(
                     type_spec, value.uri))  # pytype: disable=attribute-error
         else:
             return CompositeValue(value, type_spec)
     elif isinstance(value, pb.Computation):
         which_computation = value.WhichOneof('computation')
         if which_computation in ['tensorflow', 'lambda']:
             return CompositeValue(value, type_spec)
         elif which_computation == 'intrinsic':
             intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
             if intr is None:
                 raise ValueError(
                     'Encountered an unrecognized intrinsic "{}".'.format(
                         value.intrinsic.uri))
             py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
             return await self.create_value(intr, type_spec)
         else:
             raise NotImplementedError(
                 'Unimplemented computation type {}.'.format(
                     which_computation))
     elif isinstance(type_spec, computation_types.NamedTupleType):
         value_tuple = anonymous_tuple.from_container(value)
         items = await asyncio.gather(*[
             self.create_value(v, t)
             for v, t in zip(value_tuple, type_spec)
         ])
          type_elements_iter = anonymous_tuple.iter_elements(type_spec)
          return self.create_tuple(
              anonymous_tuple.AnonymousTuple(
                  (k, i) for (k, _), i in zip(type_elements_iter, items)))
     elif isinstance(type_spec, computation_types.FederatedType):
         if type_spec.placement == placement_literals.SERVER:
             if type_spec.all_equal:
                 return CompositeValue(
                     await self._parent_executor.create_value(
                         value, type_spec.member), type_spec)
             else:
                 raise ValueError(
                     'A non-all_equal value on the server is unexpected.')
         elif type_spec.placement == placement_literals.CLIENTS:
             if type_spec.all_equal:
                 return CompositeValue(
                     await asyncio.gather(*[
                         c.create_value(value, type_spec)
                         for c in self._child_executors
                     ]), type_spec)
             else:
                 py_typecheck.check_type(value, list)
                 if self._cardinalities is None:
                     self._cardinalities = asyncio.ensure_future(
                         self._get_cardinalities())
                 cardinalities = await self._cardinalities
                 py_typecheck.check_len(cardinalities,
                                        len(self._child_executors))
                 count = sum(cardinalities)
                 py_typecheck.check_len(value, count)
                 result = []
                 offset = 0
                 for c, n in zip(self._child_executors, cardinalities):
                     new_offset = offset + n
                      # The slice operator is not supported on all the types
                      # `value` supports.
                     # pytype: disable=unsupported-operands
                     result.append(
                         c.create_value(value[offset:new_offset],
                                        type_spec))
                     # pytype: enable=unsupported-operands
                     offset = new_offset
                 return CompositeValue(await asyncio.gather(*result),
                                       type_spec)
         else:
             raise ValueError('Unexpected placement {}.'.format(
                 type_spec.placement))
     else:
         return CompositeValue(
             await self._parent_executor.create_value(value, type_spec),
             type_spec)
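The CLIENTS branch above slices a client-placed list by child cardinalities and awaits all child placements concurrently. A self-contained sketch of that fan-out, where `child_create_value` is a hypothetical stand-in for a child executor:

import asyncio

async def child_create_value(child_id, chunk):
  # Stand-in for a child executor's create_value.
  return (child_id, chunk)

async def scatter(value, cardinalities):
  futures = []
  offset = 0
  for child_id, n in enumerate(cardinalities):
    # Each child receives the contiguous slice of clients it is responsible for.
    futures.append(child_create_value(child_id, value[offset:offset + n]))
    offset += n
  return await asyncio.gather(*futures)

print(asyncio.run(scatter([10, 20, 30], [2, 1])))  # [(0, [10, 20]), (1, [30])]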