Example #1
def extract_nodes_consuming(tree, predicate):
    """Returns the set of AST nodes which consume nodes matching `predicate`.

  Notice we adopt the convention that a node which itself satisfies the
  predicate is in this set.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree, and construct the set of nodes in this tree having a
      dependency on nodes matching `predicate`; that is, the set of nodes whose
      value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances
    representing the nodes in `tree` dependent on nodes matching `predicate`.
  """
    py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_callable(predicate)
    dependent_nodes = set()

    def _are_children_in_dependent_set(comp, symbol_tree):
        """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
        if isinstance(
                comp,
            (building_blocks.Intrinsic, building_blocks.Data,
             building_blocks.Placement, building_blocks.CompiledComputation)):
            return False
        elif isinstance(comp, building_blocks.Lambda):
            return comp.result in dependent_nodes
        elif isinstance(comp, building_blocks.Block):
            return any(x[1] in dependent_nodes
                       for x in comp.locals) or comp.result in dependent_nodes
        elif isinstance(comp, building_blocks.Tuple):
            return any(x in dependent_nodes for x in comp)
        elif isinstance(comp, building_blocks.Selection):
            return comp.source in dependent_nodes
        elif isinstance(comp, building_blocks.Call):
            return comp.function in dependent_nodes or comp.argument in dependent_nodes
        elif isinstance(comp, building_blocks.Reference):
            return _is_reference_dependent(comp, symbol_tree)

    def _is_reference_dependent(comp, symbol_tree):
        payload = symbol_tree.get_payload_with_name(comp.name)
        if payload is None:
            return False
        # The postorder traversal ensures that we process any bindings
        # before we process the references to those bindings.
        return payload.value in dependent_nodes

    def _populate_dependent_set(comp, symbol_tree):
        """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
        if predicate(comp):
            dependent_nodes.add(comp)
        elif _are_children_in_dependent_set(comp, symbol_tree):
            dependent_nodes.add(comp)
        return comp, False

    symbol_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)
    transformation_utils.transform_postorder_with_symbol_bindings(
        tree, _populate_dependent_set, symbol_tree)
    return dependent_nodes
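A hedged usage sketch of the function above; the import path for TFF's internal `building_blocks` module is an assumption and varies across TFF versions:

import tensorflow as tf
from tensorflow_federated.python.core.impl.compiler import building_blocks

ref = building_blocks.Reference('x', tf.int32)
tup = building_blocks.Tuple([ref])
# Both the matching reference itself and the tuple consuming it are returned.
consuming = extract_nodes_consuming(
    tup, lambda node: isinstance(node, building_blocks.Reference))
assert ref in consuming and tup in consuming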
def to_representation_for_type(value, type_spec=None, device=None):
  """Verifies or converts the `value` to an eager objct matching `type_spec`.

  WARNING: This function is only partially implemented. It does not support
  data sets at this point.

  The output of this function is always an eager tensor, an eager dataset, a
  representation of a TensorFlow computation, or a nested structure of those
  that matches `type_spec`, and when `device` has been specified, everything
  is placed on that device on a best-effort basis.

  TensorFlow computations are represented here as zero- or one-argument Python
  callables that accept their entire argument bundle as a single Python object.

  Args:
    value: The raw representation of a value to compare against `type_spec` and
      potentially to be converted.
    type_spec: An instance of `tff.Type`; can be `None` for values that derive
      from `typed_object.TypedObject`.
    device: The optional device to place the value on (for tensor-level values).

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
  """
  if device is not None:
    py_typecheck.check_type(device, six.string_types)
    with tf.device(device):
      return to_representation_for_type(value, type_spec=type_spec, device=None)
  type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)
  if isinstance(value, EagerValue):
    return value.internal_representation
  if isinstance(value, executor_value_base.ExecutorValue):
    raise TypeError(
        'Cannot accept a value embedded within a non-eager executor.')
  if isinstance(value, computation_base.Computation):
    return to_representation_for_type(
        computation_impl.ComputationImpl.get_proto(value), type_spec, device)
  if isinstance(value, pb.Computation):
    return embed_tensorflow_computation(value, type_spec, device)
  if isinstance(type_spec, computation_types.TensorType):
    if not isinstance(value, tf.Tensor):
      if isinstance(value, np.ndarray):
        value = tf.constant(value, dtype=type_spec.dtype)
      else:
        value = tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape)
    value_type = (
        computation_types.TensorType(value.dtype.base_dtype, value.shape))
    if not type_utils.is_assignable_from(type_spec, value_type):
      raise TypeError(
          'The apparent type {} of a tensor {} does not match the expected '
          'type {}.'.format(str(value_type), str(value), str(type_spec)))
    return value
  elif isinstance(type_spec, computation_types.NamedTupleType):
    type_elem = anonymous_tuple.to_elements(type_spec)
    value_elem = (
        anonymous_tuple.to_elements(anonymous_tuple.from_container(value)))
    result_elem = []
    if len(type_elem) != len(value_elem):
      raise TypeError('Expected a {}-element tuple, found {} elements.'.format(
          str(len(type_elem)), str(len(value_elem))))
    for (t_name, el_type), (v_name, el_val) in zip(type_elem, value_elem):
      if t_name != v_name:
        raise TypeError(
            'Mismatching element names in type vs. value: {} vs. {}.'.format(
                t_name, v_name))
      el_repr = to_representation_for_type(el_val, el_type, device)
      result_elem.append((t_name, el_repr))
    return anonymous_tuple.AnonymousTuple(result_elem)
  elif isinstance(type_spec, computation_types.SequenceType):
    if isinstance(value, list):
      value = graph_utils.make_data_set_from_elements(None, value,
                                                      type_spec.element)
    py_typecheck.check_type(
        value,
        (tf.data.Dataset, tf.compat.v1.data.Dataset, tf.compat.v2.data.Dataset))
    element_type = type_utils.tf_dtypes_and_shapes_to_type(
        tf.compat.v1.data.get_output_types(value),
        tf.compat.v1.data.get_output_shapes(value))
    value_type = computation_types.SequenceType(element_type)
    type_utils.check_assignable_from(type_spec, value_type)
    return value
  else:
    raise TypeError('Unexpected type {}.'.format(str(type_spec)))
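A minimal sketch of the tensor path above (assuming eager TF2 and the module aliases used in the snippet):

import numpy as np
import tensorflow as tf

# Embeds a numpy value as an eager tensor of the expected TFF tensor type.
rep = to_representation_for_type(
    np.array([1.0, 2.0], dtype=np.float32),
    computation_types.TensorType(tf.float32, [2]))
# `rep` is a float32 tf.Tensor of shape [2].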
Example #3
def __init__(self, identifier):
    py_typecheck.check_type(identifier, str)
    self._identifier = identifier
Example #4
def zeroing_factory(zeroing_norm: Union[float,
                                        estimation_process.EstimationProcess],
                    inner_agg_factory: factory.AggregationFactory,
                    norm_order: float = math.inf) -> factory.AggregationFactory:
  """Creates an aggregation factory to perform zeroing.

  The created `tff.templates.AggregationProcess` zeroes out any values whose
  norm is greater than that determined by the provided `zeroing_norm`, before
  aggregating the values as specified by `inner_agg_factory`. Note that for
  weighted aggregation, if some value is zeroed, the weight is unchanged. So,
  for example, if a lot of zeroing occurs in a zeroed weighted mean, the
  average will tend to be pulled toward zero. This is for consistency between
  weighted and unweighted aggregation.

  The provided `zeroing_norm` can either be a constant (for fixed norm), or an
  instance of `tff.templates.EstimationProcess` (for adaptive norm). If it is an
  estimation process, the value returned by its `report` method will be used as
  the zeroing norm. Its `next` method needs to accept a scalar float32 at
  clients, corresponding to the norm of value being aggregated. The process can
  thus adaptively determine the zeroing norm based on the set of aggregated
  values. For example if a `tff.aggregators.PrivateQuantileEstimationProcess` is
  used, the zeroing norm will be an estimate of a quantile of the norms of the
  values being aggregated.

  The returned `AggregationFactory` takes its weightedness
  (`UnweightedAggregationFactory` vs. `WeightedAggregationFactory`) from
  `inner_agg_factory`.

  Args:
    zeroing_norm: Either a float (for fixed norm) or an `EstimationProcess` (for
      adaptive norm) that specifies the norm over which the values should be
      zeroed.
    inner_agg_factory: A factory specifying the type of aggregation to be done
      after zeroing.
    norm_order: A float for the order of the norm. Must be 1., 2., or infinity.

  Returns:
    An aggregation factory to perform zeroing.
  """

  py_typecheck.check_type(norm_order, float)
  if not (norm_order in [1.0, 2.0] or math.isinf(norm_order)):
    raise ValueError('norm_order must be 1.0, 2.0 or infinity')

  def make_zero_fn(value_type):
    """Creates a zeroing function for the value_type."""

    @computations.tf_computation(value_type, NORM_TF_TYPE)
    def zero_fn(value, zeroing_norm):
      if norm_order == 1.0:
        global_norm = _global_l1_norm(value)
      elif norm_order == 2.0:
        global_norm = tf.linalg.global_norm(tf.nest.flatten(value))
      else:
        assert math.isinf(norm_order)
        global_norm = _global_inf_norm(value)
      should_zero = (global_norm > zeroing_norm)
      zeroed_value = tf.cond(
          should_zero, lambda: tf.nest.map_structure(tf.zeros_like, value),
          lambda: value)
      was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
      return zeroed_value, global_norm, was_zeroed

    return zero_fn

  return _make_wrapper(zeroing_norm, inner_agg_factory, make_zero_fn, 'zero')
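A hedged usage sketch, assuming the `mean` aggregators module referenced elsewhere on this page:

# Zero out any client value whose L-infinity norm exceeds 10.0, then average.
zeroed_mean = zeroing_factory(
    zeroing_norm=10.0,
    inner_agg_factory=mean.MeanFactory(),
    norm_order=math.inf)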
Example #5
def _check_value_type(value_type):
  py_typecheck.check_type(value_type, factory.ValueType.__args__)
  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError(f'All values in provided value_type must be of floating '
                    f'dtype. Provided value_type: {value_type}')
def append_to_list_structure_for_element_type_spec(nested, value, type_spec):
    """Adds an element `value` to `nested` lists for `type_spec`.

  This function appends tensor-level constituents of an element `value` to the
  lists created by `make_empty_list_structure_for_element_type_spec`. The
  nested structure of `value` must match that created by the above function,
  and be consistent with `type_spec`.

  Args:
    nested: Output of `make_empty_list_structure_for_element_type_spec`.
    value: A value (Python object) with a hierarchical structure of
      dictionaries, lists, and other containers holding tensor-like items that
      matches the hierarchy of `type_spec`.
    type_spec: An instance of `tff.Type` or something convertible to it, as in
      `make_empty_list_structure_for_element_type_spec`.

  Raises:
    TypeError: If the `type_spec` is not of a form described above, or the value
      is not of a type compatible with `type_spec`.
  """
    if value is None:
        return
    type_spec = computation_types.to_type(type_spec)
    # TODO(b/113116813): This could be made more efficient, but for now we won't
    # need to worry about it as this is an odd corner case.
    if isinstance(value, structure.Struct):
        elements = structure.to_elements(value)
        if all(k is not None for k, _ in elements):
            value = collections.OrderedDict(elements)
        elif all(k is None for k, _ in elements):
            value = tuple([v for _, v in elements])
        else:
            raise TypeError(
                'Expected an anonymous tuple to either have all elements named or '
                'all unnamed, got {}.'.format(value))
    if type_spec.is_tensor():
        py_typecheck.check_type(nested, list)
        # Convert the members to tensors to ensure that they are properly
        # typed and grouped before being passed to
        # tf.data.Dataset.from_tensor_slices.
        nested.append(tf.convert_to_tensor(value, type_spec.dtype))  # pytype: disable=attribute-error
    elif type_spec.is_struct():
        elements = structure.to_elements(type_spec)
        if isinstance(nested, collections.OrderedDict):
            if py_typecheck.is_named_tuple(value):
                value = value._asdict()  # pytype: disable=attribute-error
            if isinstance(value, dict):
                if set(value.keys()) != set(k for k, _ in elements):
                    raise TypeError('Value {} does not match type {}.'.format(
                        value, type_spec))
                for elem_name, elem_type in elements:
                    append_to_list_structure_for_element_type_spec(
                        nested[elem_name], value[elem_name], elem_type)
            elif isinstance(value, (list, tuple)):
                if len(value) != len(elements):
                    raise TypeError('Value {} does not match type {}.'.format(
                        value, type_spec))
                for idx, (elem_name, elem_type) in enumerate(elements):
                    append_to_list_structure_for_element_type_spec(
                        nested[elem_name], value[idx], elem_type)
            else:
                raise TypeError(
                    'Unexpected type of value {} for TFF type {}.'.format(
                        py_typecheck.type_string(type(value)), type_spec))
        elif isinstance(nested, tuple):
            py_typecheck.check_type(value, (list, tuple))
            if len(value) != len(elements):
                raise TypeError('Value {} does not match type {}.'.format(
                    value, type_spec))
            for idx, (_, elem_type) in enumerate(elements):
                append_to_list_structure_for_element_type_spec(
                    nested[idx], value[idx], elem_type)
        else:
            raise TypeError(
                'Invalid nested structure, unexpected container type {}.'.
                format(py_typecheck.type_string(type(nested))))
    else:
        raise TypeError(
            'Expected a tensor or named tuple type, found {}.'.format(
                type_spec))
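A sketch of the intended round trip; `make_empty_list_structure_for_element_type_spec` is the companion helper named in the docstring and is assumed to live in the same module:

import collections
import tensorflow as tf

type_spec = computation_types.to_type(collections.OrderedDict(x=tf.int32))
nested = make_empty_list_structure_for_element_type_spec(type_spec)
append_to_list_structure_for_element_type_spec(nested, {'x': 1}, type_spec)
append_to_list_structure_for_element_type_spec(nested, {'x': 2}, type_spec)
# `nested` now maps 'x' to a list of two int32 tensors, ready for
# tf.data.Dataset.from_tensor_slices.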
def fetch_value_in_session(sess, value):
    """Fetches `value` in `session`.

  Args:
    sess: The session in which to perform the fetch (as a single run).
    value: A Python object of a form analogous to that constructed by the
      function `assemble_result_from_graph`, made of tensors and anonymous
      tuples, or a `tf.data.Dataset`.

  Returns:
    A Python object with structure similar to `value`, but with tensors
    replaced with their values, and data sets replaced with lists of their
    elements, all fetched with a single call to `sess.run()`.

  Raises:
    ValueError: If `value` is not a `tf.data.Dataset` or not a structure of
      tensors and anonymous tuples.
  """
    py_typecheck.check_type(sess, tf.compat.v1.Session)
    # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
    # `tf.data.Dataset`s and what the API would look like to support this.
    if isinstance(value, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
        with sess.graph.as_default():
            iterator = tf.compat.v1.data.make_one_shot_iterator(value)
            next_element = iterator.get_next()
        elements = []
        while True:
            try:
                elements.append(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break
        return elements
    else:
        flattened_value = structure.flatten(value)
        dataset_results = {}
        flat_tensors = []
        for idx, v in enumerate(flattened_value):
            if isinstance(v, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
                dataset_tensors = fetch_value_in_session(sess, v)
                if not dataset_tensors:
                    # An empty list has been returned; we must pack the shape information
                    # back in or the result won't typecheck.
                    element_structure = v.element_spec
                    dummy_elem = make_dummy_element_for_type_spec(
                        element_structure)
                    dataset_tensors = [dummy_elem]
                dataset_results[idx] = dataset_tensors
            elif tf.is_tensor(v):
                flat_tensors.append(v)
            else:
                raise ValueError('Unsupported value type {}.'.format(v))
        # Note that `flat_tensors` could be an empty tuple, but it could also be a
        # list of empty tuples.
        if flat_tensors or any(x for x in flat_tensors):
            flat_computed_tensors = sess.run(flat_tensors)
        else:
            flat_computed_tensors = flat_tensors
        flattened_results = _interleave_dataset_results_and_tensors(
            dataset_results, flat_computed_tensors)

        def _to_unicode(v):
            if isinstance(v, bytes):
                return v.decode('utf-8')
            return v

        if tf.is_tensor(value) and value.dtype == tf.string:
            flattened_results = [
                _to_unicode(result) for result in flattened_results
            ]
        return structure.pack_sequence_as(value, flattened_results)
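A minimal sketch of the dataset branch (graph-mode TF, as required by the `tf.compat.v1.Session` check above):

graph = tf.Graph()
with graph.as_default():
    ds = tf.data.Dataset.range(3)
with tf.compat.v1.Session(graph=graph) as sess:
    elements = fetch_value_in_session(sess, ds)  # expected: [0, 1, 2]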
Example #8
def from_keras_model(
    keras_model: tf.keras.Model,
    loss: Loss,
    input_spec,
    loss_weights: Optional[List[float]] = None,
    metrics: Optional[List[tf.keras.metrics.Metric]] = None
) -> model_lib.Model:
    """Builds a `tff.learning.Model` from a `tf.keras.Model`.

  The `tff.learning.Model` returned by this function uses `keras_model` for
  its forward pass and autodifferentiation steps.

  Notice that since TFF couples the `tf.keras.Model` and `loss`,
  TFF needs a slightly different notion of "fully specified type" than
  pure Keras does. That is, the model `M` takes inputs of type `x` and
  produces predictions of type `p`; the loss function `L` takes inputs of type
  `<p, y>` and produces a scalar. Therefore in order to fully specify the type
  signatures for computations in which the generated `tff.learning.Model` will
  appear, TFF needs the type `y` in addition to the type `x`.

  Args:
    keras_model: A `tf.keras.Model` object that is not compiled.
    loss: A `tf.keras.losses.Loss`, or a list of losses-per-output if the model
      has multiple outputs. If multiple outputs are present, the model will
      attempt to minimize the sum of all individual losses (optionally weighted
      using the `loss_weights` argument).
    input_spec: A structure of `tf.TensorSpec`s or `tff.Type` specifying the
      type of arguments the model expects. Notice this must be a compound
      structure of two elements, specifying both the data fed into the model (x)
      to generate predictions as well as the expected type of the ground truth
      (y). If provided as a list, it must be in the order [x, y]. If provided as
      a dictionary, the keys must explicitly be named `'x'` and `'y'`.
    loss_weights: (Optional) A list of Python floats used to weight the loss
      contribution of each model output.
    metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.

  Returns:
    A `tff.learning.Model` object.

  Raises:
    TypeError: If `keras_model` is not instance of `tf.keras.Model`, if
      `keras_model` has a single output and `loss` is not instance of
      `tf.keras.losses.Loss`, or if `keras_model` has multiple outputs and
      `loss` is not a list of instances of `tf.keras.losses.Loss`.
    ValueError: If `keras_model` was compiled, if `keras_model` has multiple
      outputs and `loss` is not list of equal length, if `input_spec` does not
      contain exactly two elements, or if `input_spec` is a dictionary and does
      not contain keys `'x'` and `'y'`.
  """
    # Validate `keras_model`
    py_typecheck.check_type(keras_model, tf.keras.Model)
    if keras_model._is_compiled:  # pylint: disable=protected-access
        raise ValueError('`keras_model` must not be compiled')

    # Validate and normalize `loss` and `loss_weights`
    if len(keras_model.outputs) == 1:
        py_typecheck.check_type(loss, tf.keras.losses.Loss)
        if loss_weights is not None:
            raise ValueError(
                '`loss_weights` cannot be used if `keras_model` has '
                'only one output.')
        loss = [loss]
        loss_weights = [1.0]
    else:
        py_typecheck.check_type(loss, list)
        if len(loss) != len(keras_model.outputs):
            raise ValueError('`keras_model` must have equal number of '
                             'outputs and losses.\nloss: {}\nof length: {}.'
                             '\noutputs: {}\nof length: {}.'.format(
                                 loss, len(loss), keras_model.outputs,
                                 len(keras_model.outputs)))
        for loss_fn in loss:
            py_typecheck.check_type(loss_fn, tf.keras.losses.Loss)

        if loss_weights is None:
            loss_weights = [1.0] * len(loss)
        else:
            if len(loss) != len(loss_weights):
                raise ValueError(
                    '`keras_model` must have equal number of losses and loss_weights.'
                    '\nloss: {}\nof length: {}.'
                    '\nloss_weights: {}\nof length: {}.'.format(
                        loss, len(loss), loss_weights, len(loss_weights)))
            for loss_weight in loss_weights:
                py_typecheck.check_type(loss_weight, float)

    if len(input_spec) != 2:
        raise ValueError(
            'The top-level structure in `input_spec` must contain '
            'exactly two top-level elements, as it must specify type '
            'information for both inputs to and predictions from the '
            'model. You passed input spec {}.'.format(input_spec))
    if not isinstance(input_spec, computation_types.Type):
        for input_spec_member in tf.nest.flatten(input_spec):
            py_typecheck.check_type(input_spec_member, tf.TensorSpec)
    else:
        for type_elem in input_spec:
            py_typecheck.check_type(type_elem, computation_types.TensorType)
    if isinstance(input_spec, collections.Mapping):
        if 'x' not in input_spec:
            raise ValueError(
                'The `input_spec` is a collections.Mapping (e.g., a dict), so it '
                'must contain an entry with key `\'x\'`, representing the input(s) '
                'to the Keras model.')
        if 'y' not in input_spec:
            raise ValueError(
                'The `input_spec` is a collections.Mapping (e.g., a dict), so it '
                'must contain an entry with key `\'y\'`, representing the label(s) '
                'to be used in the Keras loss(es).')

    if metrics is None:
        metrics = []
    else:
        py_typecheck.check_type(metrics, list)
        for metric in metrics:
            py_typecheck.check_type(metric, tf.keras.metrics.Metric)

    return model_utils.enhance(
        _KerasModel(keras_model,
                    input_spec=input_spec,
                    loss_fns=loss,
                    loss_weights=loss_weights,
                    metrics=metrics))
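A hedged usage sketch: a one-output regression model, where `input_spec` pairs the model input (x) with the expected label (y):

import collections
import tensorflow as tf

keras_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(2,))])
tff_model = from_keras_model(
    keras_model,
    loss=tf.keras.losses.MeanSquaredError(),
    input_spec=collections.OrderedDict(
        x=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32)))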
def _client_fn(model,
               initial_model_weights,
               train_data,
               test_data,
               personalize_fn_dict,
               baseline_evaluate_fn,
               context=None):
  """The main `tf.function` that runs on device.

  This function first evaluates the initial model and gets the baseline metrics.
  Then starting from the same initial model, this function iterates over the
  personalization strategies defined in `personalize_fn_dict`, trains and
  evaluates the personalized models, and returns the evaluation metrics.

  Args:
    model: A `tff.learning.Model`.
    initial_model_weights: A `tff.learning.framework.ModelWeights` containing
      `tf.Tensor`s that hold trainable and non-trainable weights.
    train_data: A `tf.data.Dataset` used for training.
    test_data: A `tf.data.Dataset` used for evaluation.
    personalize_fn_dict: This is the same argument specified in the function
      `build_personalization_eval` above; see its documentation for details.
    baseline_evaluate_fn: This is the same argument specified in the function
      `build_personalization_eval` above; see its documentation for details.
    context: An optional object used in `personalize_fn_dict`. If used, its
      `tff.Type` must be provided by passing the correct `context_tff_type`
      argument to the `build_personalization_eval` function.

  Returns:
    An `OrderedDict` that maps a string 'baseline_metrics' to the evaluation
    metrics of the initial model (computed by `baseline_evaluate_fn`), and maps
    keys (strategy names) in `personalize_fn_dict` to the evaluation metrics of
    the corresponding personalization strategies.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
  """
  # Wrap the input model as an `EnhancedModel` for easy access of its weights.
  model = model_utils.enhance(model)

  final_metrics = collections.OrderedDict()
  tff.utils.assign(model.weights, initial_model_weights)
  py_typecheck.check_callable(baseline_evaluate_fn)
  final_metrics['baseline_metrics'] = baseline_evaluate_fn(model, test_data)

  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  if 'baseline_metrics' in personalize_fn_dict:
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')

  for name, personalize_fn_builder in personalize_fn_dict.items():
    py_typecheck.check_type(name, str)
    tff.utils.assign(model.weights, initial_model_weights)

    # Construct the `personalize_fn` (and the associated `tf.Variable`s) here.
    # Once `_client_fn` is decorated with `tff.tf_computation`, construction of
    # the new variables will happen in a scope controlled by TFF. Ensuring
    # `tf.Variable`s are created in the graphs that TFF controls is the reason
    # we need `personalize_fn_dict` to contain no-argument functions that build
    # the desired `tf.function`s, rather than already built `tf.function`s.
    py_typecheck.check_callable(personalize_fn_builder)
    personalize_fn = personalize_fn_builder()

    py_typecheck.check_callable(personalize_fn)
    final_metrics[name] = personalize_fn(model, train_data, test_data, context)

  return final_metrics
Example #10
    async def create_value(self, value, type_spec=None):
        """Creates a value in this executor.

    The following kinds of `value` are supported as the input:

    * An instance of TFF computation proto containing one of the supported
      sequence intrinsics as its sole body.

    * An instance of eager TF dataset.

    * Anything that is supported by the target executor (as a pass-through).

    * A nested structure of any of the above.

    Args:
      value: The input for which to create a value.
      type_spec: An optional TFF type (required if `value` is not an instance of
        `typed_object.TypedObject`, otherwise it can be `None`).

    Returns:
      An instance of `SequenceExecutorValue` that represents the embedded value.
    """
        if type_spec is None:
            py_typecheck.check_type(value, typed_object.TypedObject)
            type_spec = value.type_signature
        else:
            type_spec = computation_types.to_type(type_spec)
        if isinstance(type_spec, computation_types.SequenceType):
            return SequenceExecutorValue(
                _SequenceFromPayload(value, type_spec), type_spec)
        if isinstance(value, pb.Computation):
            value_type = type_serialization.deserialize_type(value.type)
            value_type.check_equivalent_to(type_spec)
            which_computation = value.WhichOneof('computation')
            # NOTE: If not a supported type of intrinsic, we let it fall through and
            # be handled by embedding in the target executor (below).
            if which_computation == 'intrinsic':
                intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(
                    value.intrinsic.uri)
                if intrinsic_def is None:
                    raise ValueError(
                        'Encountered an unrecognized intrinsic "{}".'.format(
                            value.intrinsic.uri))
                op_type = SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
                    intrinsic_def.uri)
                if op_type is not None:
                    type_analysis.check_concrete_instance_of(
                        type_spec, intrinsic_def.type_signature)
                    op = op_type(type_spec)
                    return SequenceExecutorValue(op, type_spec)
        if isinstance(type_spec, computation_types.StructType):
            if not isinstance(value, structure.Struct):
                value = structure.from_container(value)
            elements = structure.flatten(value)
            element_types = structure.flatten(type_spec)
            flat_embedded_vals = await asyncio.gather(*[
                self.create_value(el, el_type)
                for el, el_type in zip(elements, element_types)
            ])
            embedded_struct = structure.pack_sequence_as(
                value, flat_embedded_vals)
            return await self.create_struct(embedded_struct)
        target_value = await self._target_executor.create_value(
            value, type_spec)
        return SequenceExecutorValue(target_value, type_spec)
Example #11
def __init__(self, type_spec: computation_types.SequenceType):
    py_typecheck.check_type(type_spec, computation_types.SequenceType)
    self._type_signature = type_spec
Example #12
def building_block_to_computation(building_block):
    """Converts a computation building block to a computation impl."""
    py_typecheck.check_type(building_block,
                            building_blocks.ComputationBuildingBlock)
    return computation_impl.ComputationImpl(building_block.proto,
                                            context_stack_impl.context_stack)
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
  """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENT)`.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.WeightedAggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation. Must
      be `None` if `aggregation_process` is not `None`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
    DisjointArgumentError: if both `aggregation_process` and
      `model_update_aggregation_factory` are not `None`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  model_weights_type = model_utils.weights_type_from_model(model_fn)

  if broadcast_process is None:
    broadcast_process = build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  if not _is_valid_broadcast_process(broadcast_process):
    raise ProcessTypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))

  if (model_update_aggregation_factory is not None and
      aggregation_process is not None):
    raise DisjointArgumentError(
        'Must specify only one of `model_update_aggregation_factory` and '
        '`AggregationProcess`.')

  if aggregation_process is None:
    if model_update_aggregation_factory is None:
      model_update_aggregation_factory = mean.MeanFactory()
    py_typecheck.check_type(model_update_aggregation_factory,
                            factory.AggregationFactory.__args__)
    if isinstance(model_update_aggregation_factory,
                  factory.WeightedAggregationFactory):
      aggregation_process = model_update_aggregation_factory.create(
          model_weights_type.trainable,
          computation_types.TensorType(tf.float32))
    else:
      aggregation_process = model_update_aggregation_factory.create(
          model_weights_type.trainable)
    process_signature = aggregation_process.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
      raise TypeError('`model_update_aggregation_factory` does not produce a '
                      'compatible `AggregationProcess`. The processes must '
                      'retain the type structure of the inputs on the '
                      f'server, but got {input_client_value_type.member} != '
                      f'{result_server_value_type.member}.')
  else:
    next_num_args = len(aggregation_process.next.type_signature.parameter)
    if next_num_args not in [2, 3]:
      raise ValueError(
          f'`next` function of `aggregation_process` must take two (for '
          f'unweighted aggregation) or three (for weighted aggregation) '
          f'arguments. Found {next_num_args}.')

  if not _is_valid_model_update_aggregation_process(aggregation_process):
    raise ProcessTypeError(
        'aggregation_process type signature does not conform to expected '
        'signature (<state@S, model_update@C> -> <state@S, model_update@S, '
        'measurements@S>). Got: {t}'.format(
            t=aggregation_process.next.type_signature))

  initialize_computation = _build_initialize_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  run_one_round_computation = _build_one_round_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      model_to_client_delta_fn=model_to_client_delta_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  return iterative_process.IterativeProcess(
      initialize_fn=initialize_computation, next_fn=run_one_round_computation)
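An illustrative driver loop for the returned process. All names here are hypothetical: `model_fn` is a no-arg `tff.learning.Model` constructor, `my_client_delta_fn` is assumed to build a `ClientDeltaFn`, and `federated_train_data` is client data prepared elsewhere:

process = build_model_delta_optimizer_process(
    model_fn=model_fn,
    model_to_client_delta_fn=my_client_delta_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
state = process.initialize()
for _ in range(3):
  state, metrics = process.next(state, federated_train_data)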
Example #14
def _trees_equal(comp_1, comp_2):
    """Returns `True` if the computations are entirely identical.

  If you pass objects other than instances of
  `building_blocks.ComputationBuildingBlock` this function will
  return `False`. Structurally equivalent computations with different variable
  names or different operation orderings are not considered to be equal.

  Args:
    comp_1: A `building_blocks.ComputationBuildingBlock` to test.
    comp_2: A `building_blocks.ComputationBuildingBlock` to test.

  Raises:
    TypeError: If `comp_1` or `comp_2` is not an instance of
      `building_blocks.ComputationBuildingBlock`.
    NotImplementedError: If `comp_1` and `comp_2` are an unexpected subclass of
      `building_blocks.ComputationBuildingBlock`.
  """
    # TODO(b/146892021): TFF needs a structural AST equality function, which
    # needs to be public. There is a necessary dependency on this function from
    # the TFF-to-TF code generation pipeline, in order to detect some structural
    # equivalence while generating TensorFlow. It was decided that it is
    # preferable to expose a dependency on this "private" function, and file the
    # bug here, rather than effectively duplicate the logic elsewhere.
    py_typecheck.check_type(comp_1, building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(comp_2, building_blocks.ComputationBuildingBlock)
    if comp_1 is comp_2:
        return True
    # The unidiomatic-typecheck is intentional, for the purposes of equality this
    # function requires that the types are identical and that a subclass will not
    # be equal to its baseclass.
    if type(comp_1) != type(comp_2):  # pylint: disable=unidiomatic-typecheck
        return False
    if comp_1.type_signature != comp_2.type_signature:
        return False
    if isinstance(comp_1, building_blocks.Block):
        if not _trees_equal(comp_1.result, comp_2.result):
            return False
        if len(comp_1.locals) != len(comp_2.locals):
            return False
        for (name_1, value_1), (name_2,
                                value_2) in zip(comp_1.locals, comp_2.locals):
            if name_1 != name_2 or not _trees_equal(value_1, value_2):
                return False
        return True
    elif isinstance(comp_1, building_blocks.Call):
        return (_trees_equal(comp_1.function, comp_2.function)
                and (comp_1.argument is None and comp_2.argument is None
                     or _trees_equal(comp_1.argument, comp_2.argument)))
    elif isinstance(comp_1, building_blocks.CompiledComputation):
        return _compiled_comp_equal(comp_1, comp_2)
    elif isinstance(comp_1, building_blocks.Data):
        return comp_1.uri == comp_2.uri
    elif isinstance(comp_1, building_blocks.Intrinsic):
        return comp_1.uri == comp_2.uri
    elif isinstance(comp_1, building_blocks.Lambda):
        return (comp_1.parameter_name == comp_2.parameter_name
                and comp_1.parameter_type == comp_2.parameter_type
                and _trees_equal(comp_1.result, comp_2.result))
    elif isinstance(comp_1, building_blocks.Placement):
        return comp_1.uri == comp_2.uri
    elif isinstance(comp_1, building_blocks.Reference):
        return comp_1.name == comp_2.name
    elif isinstance(comp_1, building_blocks.Selection):
        return (comp_1.name == comp_2.name and comp_1.index == comp_2.index
                and _trees_equal(comp_1.source, comp_2.source))
    elif isinstance(comp_1, building_blocks.Tuple):
        # The element names are checked as part of the `type_signature`.
        if len(comp_1) != len(comp_2):
            return False
        for element_1, element_2 in zip(comp_1, comp_2):
            if not _trees_equal(element_1, element_2):
                return False
        return True
    raise NotImplementedError('Unexpected type found: {}.'.format(
        type(comp_1)))
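A sketch of the base cases, assuming the `building_blocks` import noted earlier on this page:

import tensorflow as tf

ref_a = building_blocks.Reference('x', tf.int32)
ref_b = building_blocks.Reference('x', tf.int32)
assert _trees_equal(ref_a, ref_b)  # identical name and type signature
assert not _trees_equal(ref_a, building_blocks.Reference('y', tf.int32))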
def assemble_result_from_graph(type_spec, binding, output_map):
    """Assembles a result stamped into a `tf.Graph` given type signature/binding.

  This method does roughly the opposite of `capture_result_from_graph`, in that
  whereas `capture_result_from_graph` starts with a single structured object
  made up of tensors and computes its type and bindings, this method starts
  with the type/bindings and constructs a structured object made up of tensors.

  Args:
    type_spec: The type signature of the result to assemble, an instance of
      `types.Type` or something convertible to it.
    binding: The binding that relates the type signature to names of tensors in
      the graph, an instance of `pb.TensorFlow.Binding`.
    output_map: The mapping from tensor names that appear in the binding to
      actual stamped tensors (possibly renamed during import).

  Returns:
    The assembled result, a Python object that is composed of tensors, possibly
    nested within Python structures such as anonymous tuples.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
    ValueError: If the arguments are invalid or inconsistent with each other,
      e.g., the type and binding don't match, or the tensor is not found in
      the map.
  """
    type_spec = computation_types.to_type(type_spec)
    py_typecheck.check_type(type_spec, computation_types.Type)
    py_typecheck.check_type(binding, pb.TensorFlow.Binding)
    py_typecheck.check_type(output_map, dict)
    for k, v in output_map.items():
        py_typecheck.check_type(k, str)
        if not tf.is_tensor(v):
            raise TypeError(
                'Element with key {} in the output map is {}, not a tensor.'.
                format(k, py_typecheck.type_string(type(v))))

    binding_oneof = binding.WhichOneof('binding')
    if type_spec.is_tensor():
        if binding_oneof != 'tensor':
            raise ValueError(
                'Expected a tensor binding, found {}.'.format(binding_oneof))
        elif binding.tensor.tensor_name not in output_map:
            raise ValueError(
                'Tensor named {} not found in the output map.'.format(
                    binding.tensor.tensor_name))
        else:
            return output_map[binding.tensor.tensor_name]
    elif type_spec.is_struct():
        if binding_oneof != 'struct':
            raise ValueError(
                'Expected a struct binding, found {}.'.format(binding_oneof))
        else:
            type_elements = structure.to_elements(type_spec)
            if len(binding.struct.element) != len(type_elements):
                raise ValueError(
                    'Mismatching tuple sizes in type ({}) and binding ({}).'.
                    format(len(type_elements), len(binding.struct.element)))
            result_elements = []
            for (element_name,
                 element_type), element_binding in zip(type_elements,
                                                       binding.struct.element):
                element_object = assemble_result_from_graph(
                    element_type, element_binding, output_map)
                result_elements.append((element_name, element_object))
            if type_spec.python_container is None:
                return structure.Struct(result_elements)
            container_type = type_spec.python_container
            if (py_typecheck.is_named_tuple(container_type)
                    or py_typecheck.is_attrs(container_type)):
                return container_type(**dict(result_elements))
            return container_type(result_elements)
    elif type_spec.is_sequence():
        if binding_oneof != 'sequence':
            raise ValueError(
                'Expected a sequence binding, found {}.'.format(binding_oneof))
        else:
            sequence_oneof = binding.sequence.WhichOneof('binding')
            if sequence_oneof == 'variant_tensor_name':
                variant_tensor = output_map[
                    binding.sequence.variant_tensor_name]
                return make_dataset_from_variant_tensor(
                    variant_tensor, type_spec.element)
            else:
                raise ValueError('Unsupported sequence binding \'{}\'.'.format(
                    sequence_oneof))
    else:
        raise ValueError('Unsupported type \'{}\'.'.format(type_spec))
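A minimal sketch of the tensor branch, with a hand-built binding (proto field names as used by this module):

with tf.Graph().as_default():
    t = tf.constant(1.0, name='foo')
    binding = pb.TensorFlow.Binding(
        tensor=pb.TensorFlow.TensorBinding(tensor_name=t.name))
    result = assemble_result_from_graph(
        computation_types.TensorType(tf.float32), binding, {t.name: t})
    # `result` is the stamped tensor `t` itself.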
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and a `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.data.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where each
      metric is mapped to a list of scalars (each scalar comes from one client).
      Metric values at the same position, e.g., metric_1[i], metric_2[i]..., all
      come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
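A hedged sketch of wiring one personalization strategy. The dict holds a no-arg builder so that the `tf.function`'s variables are created inside the graph scope TFF controls (see `_client_fn` above); `model_fn` and `baseline_evaluate_fn` are assumed to be defined as the docstring describes:

import collections
import tensorflow as tf

def build_trivial_strategy_fn():
  @tf.function
  def trivial_strategy(model, train_data, test_data, context=None):
    # No training here; just report the test-set size as a metric.
    num_test = test_data.reduce(0, lambda n, _: n + 1)
    return collections.OrderedDict(num_test_examples=num_test)
  return trivial_strategy

eval_comp = build_personalization_eval(
    model_fn=model_fn,
    personalize_fn_dict=collections.OrderedDict(
        trivial=build_trivial_strategy_fn),
    baseline_evaluate_fn=baseline_evaluate_fn)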
def get_tf_typespec_and_binding(parameter_type, arg_names, unpack=None):
    """Computes a `TensorSpec` input_signature and bindings for parameter_type.

  This is the TF2 analog to `stamp_parameter_in_graph`.

  Args:
    parameter_type: The TFF type of the input to a tensorflow function. Must be
      either an instance of computation_types.Type (or convertible to it), or
      None in the case of a no-arg function.
    arg_names: String names for any positional arguments to the tensorflow
      function.
    unpack: Whether or not to unpack parameter_type into args and kwargs. See
      e.g. `function_utils.pack_args_into_struct`.

  Returns:
    A tuple (args_typespec, kwargs_typespec, binding), where args_typespec is a
    list and kwargs_typespec is a dict, both containing `tf.TensorSpec`
    objects. These structures are intended to be passed to the
    `get_concrete_function` method of a `tf.function`.
    Note the "binding" is "preliminary" in that it includes the names embedded
    in the TensorSpecs produced; these must be converted to the names of actual
    tensors based on the SignatureDef of the SavedModel before the binding is
    finalized.
  """
    if parameter_type is None:
        return ([], {}, None)
    if unpack:
        arg_types, kwarg_types = function_utils.unpack_args_from_struct(
            parameter_type)
        pack_in_struct = True
    else:
        pack_in_struct = False
        arg_types, kwarg_types = [parameter_type], {}

    py_typecheck.check_type(arg_names, collections.Iterable)
    if len(arg_names) < len(arg_types):
        raise ValueError(
            'If provided, arg_names must be a list of at least {} strings to '
            'match the number of positional arguments. Found: {}'.format(
                len(arg_types), arg_names))

    get_unique_name = UniqueNameFn()

    def _get_one_typespec_and_binding(parameter_name, parameter_type):
        """Returns a (tf.TensorSpec, binding) pair."""
        parameter_type = computation_types.to_type(parameter_type)
        if parameter_type.is_tensor():
            name = get_unique_name(parameter_name)
            tf_spec = tf.TensorSpec(shape=parameter_type.shape,
                                    dtype=parameter_type.dtype,
                                    name=name)
            binding = pb.TensorFlow.Binding(tensor=pb.TensorFlow.TensorBinding(
                tensor_name=name))
            return (tf_spec, binding)
        elif parameter_type.is_struct():
            element_typespec_pairs = []
            element_bindings = []
            have_names = False
            have_nones = False
            for e_name, e_type in structure.iter_elements(parameter_type):
                if e_name is None:
                    have_nones = True
                else:
                    have_names = True
                name = '_'.join([n for n in [parameter_name, e_name] if n])
                e_typespec, e_binding = _get_one_typespec_and_binding(
                    name if name else None, e_type)
                element_typespec_pairs.append((e_name, e_typespec))
                element_bindings.append(e_binding)
            # For a given argument or kwarg, we shouldn't have both:
            if (have_names and have_nones):
                raise ValueError(
                    'A mix of named and unnamed entries is not supported inside a '
                    'nested structure representing a single argument in a call to a '
                    'TensorFlow or Python function.\n{}'.format(
                        parameter_type))
            tf_typespec = structure.Struct(element_typespec_pairs)
            return (tf_typespec,
                    pb.TensorFlow.Binding(struct=pb.TensorFlow.StructBinding(
                        element=element_bindings)))
        elif parameter_type.is_sequence():
            raise NotImplementedError(
                'Sequence inputs not yet supported for TF 2.0.')
        else:
            raise ValueError(
                'Parameter type component {!r} cannot be converted to a TensorSpec'
                .format(parameter_type))

    def get_arg_name(i):
        name = arg_names[i]
        if not isinstance(name, str):
            raise ValueError(
                'arg_names must be strings, but got: {}'.format(name))
        return name

    # Main logic --- process arg_types and kwarg_types:
    arg_typespecs = []
    kwarg_typespecs = {}
    bindings = []
    for i, arg_type in enumerate(arg_types):
        name = get_arg_name(i)
        typespec, binding = _get_one_typespec_and_binding(name, arg_type)
        typespec = type_conversions.type_to_py_container(typespec, arg_type)
        arg_typespecs.append(typespec)
        bindings.append(binding)
    for name, kwarg_type in kwarg_types.items():
        typespec, binding = _get_one_typespec_and_binding(name, kwarg_type)
        typespec = type_conversions.type_to_py_container(typespec, kwarg_type)
        kwarg_typespecs[name] = typespec
        bindings.append(binding)

    assert bindings, 'Given parameter_type {}, but produced no bindings.'.format(
        parameter_type)
    if pack_in_struct:
        final_binding = pb.TensorFlow.Binding(
            struct=pb.TensorFlow.StructBinding(element=bindings))
    else:
        final_binding = bindings[0]

    return (arg_typespecs, kwarg_typespecs, final_binding)
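
A quick illustration of the two naming rules above, minus the proto plumbing (plain TensorFlow only; the names here are made up):

import tensorflow as tf

# Nested element names are underscore-joined, skipping empty parts
# (the rule in _get_one_typespec_and_binding above).
parameter_name, e_name = 'arg0', 'weights'
name = '_'.join([n for n in [parameter_name, e_name] if n])
assert name == 'arg0_weights'

# Each tensor leaf then becomes a named tf.TensorSpec; the same name is
# what the corresponding TensorBinding would carry.
spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name=name)
assert spec.name == 'arg0_weights'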
Example #18
0
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
    """Deserializes a TF computation and inserts it into `graph`.

  This method performs an action that can be considered roughly the opposite of
  what `tensorflow_serialization.serialize_py_func_as_tf_computation` does. At
  the moment, it simply imports the graph in the current context. A future
  implementation may rely on different mechanisms. The caller should not be
  concerned with the specifics of the implementation. At this point, the method
  is expected to only be used within the body of another TF computation (within
  an instance of `tf_computation_context.TensorFlowComputationContext` at the
  top of the stack), and potentially also in certain types of interpreted
  execution contexts (TBD).

  Args:
    computation_proto: An instance of `pb.Computation` with the `computation`
      oneof equal to `tensorflow`, to be deserialized and called.
    arg: The argument to invoke the computation with, or None if the computation
      does not specify a parameter type and does not expect one.
    graph: The graph to stamp into.

  Returns:
    A tuple (init_op, result) where:
       init_op:  String name of an op to initialize the graph.
       result: The results to be fetched from TensorFlow. Depending on
           the type of the result, this can be `tf.Tensor` or `tf.data.Dataset`
           instances, or a nested structure (such as an
           `anonymous_tuple.AnonymousTuple`).

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
    py_typecheck.check_type(computation_proto, pb.Computation)
    computation_oneof = computation_proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise ValueError('Expected a TensorFlow computation, got {}.'.format(
            computation_oneof))
    py_typecheck.check_type(graph, tf.Graph)
    with graph.as_default():
        type_spec = type_serialization.deserialize_type(computation_proto.type)
        if not type_spec.parameter:
            if arg is None:
                input_map = None
            else:
                raise TypeError(
                    'The computation declared no parameters; encountered an unexpected '
                    'argument {}.'.format(str(arg)))
        elif arg is None:
            raise TypeError(
                'The computation declared a parameter of type {}, but the argument '
                'was not supplied.'.format(str(type_spec.parameter)))
        else:
            arg_type, arg_binding = graph_utils.capture_result_from_graph(
                arg, graph)
            if not type_utils.is_assignable_from(type_spec.parameter,
                                                 arg_type):
                raise TypeError(
                    'The computation declared a parameter of type {}, but the argument '
                    'is of a mismatching type {}.'.format(
                        str(type_spec.parameter), str(arg_type)))
            else:
                input_map = {
                    k: graph.get_tensor_by_name(v)
                    for k, v in six.iteritems(
                        graph_utils.compute_map_from_bindings(
                            computation_proto.tensorflow.parameter,
                            arg_binding))
                }
        return_elements = graph_utils.extract_tensor_names_from_binding(
            computation_proto.tensorflow.result)
        orig_init_op_name = computation_proto.tensorflow.initialize_op
        if orig_init_op_name:
            return_elements.append(orig_init_op_name)
        # N.B. Unlike MetaGraphDef, the GraphDef alone contains no information
        # about collections, and hence, when we import a graph with Variables,
        # those Variables are not added to global collections, so an op built
        # by tf.compat.v1.global_variables_initializer() will not include
        # their initialization ops.
        output_tensors = tf.import_graph_def(
            computation_proto.tensorflow.graph_def,
            input_map,
            return_elements,
            # N.B. It is very important not to return any names from the
            # original computation_proto.tensorflow.graph_def, as those names
            # may or may not be valid in the current graph. Using a different
            # scope makes the graph somewhat more readable, since _N-style
            # de-duplication of graph node names is less likely to be needed.
            name='subcomputation')

        output_map = {k: v for k, v in zip(return_elements, output_tensors)}
        new_init_op_name = output_map.pop(orig_init_op_name, None)
        return (new_init_op_name,
                graph_utils.assemble_result_from_graph(
                    type_spec.result, computation_proto.tensorflow.result,
                    output_map))
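
As a minimal standalone sketch of the import step above (plain TF 1.x-style graphs, no TFF protos): build a GraphDef in one graph, then stamp it into a host graph with an input_map and fetch its outputs by name.

import tensorflow as tf

# Source graph: y = x * 2.
source_graph = tf.Graph()
with source_graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name='x')
    y = tf.multiply(x, 2.0, name='y')

# Host graph: feed a constant where 'x:0' used to be, and fetch 'y:0'.
host_graph = tf.Graph()
with host_graph.as_default():
    const_in = tf.constant(21.0)
    out, = tf.import_graph_def(
        source_graph.as_graph_def(),
        input_map={'x:0': const_in},
        return_elements=['y:0'],
        name='subcomputation')  # a scope keeps imported names readable
    with tf.compat.v1.Session(graph=host_graph) as sess:
        print(sess.run(out))  # 42.0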
def make_data_set_from_elements(graph, elements, element_type):
    """Creates a `tf.data.Dataset` in `graph` from explicitly listed `elements`.

  Note: The underlying implementation attempts to use the
  `tf.data.Dataset.from_tensor_slices()` method to build the data set quickly,
  but this doesn't always work. The typical scenario where it breaks is one
  in which the data set is composed of unequal batches. Typically, only the last
  batch is odd, so on the first attempt, we try to construct two data sets,
  one from all elements but the last one, and one from the last element, then
  concatenate the two. In the unlikely case that this fails (e.g., because
  all data set elements are batches of unequal sizes), we revert to the slow,
  but reliable method of constructing data sets from singleton elements, and
  then concatenating them all.

  Args:
    graph: The graph in which to construct the `tf.data.Dataset`, or `None` if
      the construction is to happen in the eager context.
    elements: A list of elements.
    element_type: The type of elements.

  Returns:
    The constructed `tf.data.Dataset` instance.

  Raises:
    TypeError: If element types do not match `element_type`.
    ValueError: If the elements are of incompatible types and shapes, or if
      no graph was specified outside of the eager context.
  """
    # Note: We allow the graph to be `None` to allow this function to be used in
    # the eager context.
    if graph is not None:
        py_typecheck.check_type(graph, tf.Graph)
    elif not tf.executing_eagerly():
        raise ValueError('Only in eager context may the graph be `None`.')
    py_typecheck.check_type(elements, list)
    element_type = computation_types.to_type(element_type)
    py_typecheck.check_type(element_type, computation_types.Type)

    def _make(element_subset):
        lists = make_empty_list_structure_for_element_type_spec(element_type)
        for el in element_subset:
            append_to_list_structure_for_element_type_spec(
                lists, el, element_type)
        tensor_slices = replace_empty_leaf_lists_with_numpy_arrays(
            lists, element_type)
        return tf.data.Dataset.from_tensor_slices(tensor_slices)

    def _work():  # pylint: disable=missing-docstring
        if not elements:
            # Just return an empty data set with the appropriate types.
            dummy_element = make_dummy_element_for_type_spec(element_type)
            ds = _make([dummy_element]).take(0)
        elif len(elements) == 1:
            ds = _make(elements)
        else:
            try:
                # It is common for the last element to be a batch of a size different
                # from all the preceding batches. With this in mind, we proactively
                # single out the last element (optimizing for the common case).
                ds = _make(elements[0:-1]).concatenate(_make(elements[-1:]))
            except ValueError:
                # In case elements beyond just the last one are of unequal shapes, we
                # may have failed (the most likely cause), so fall back onto the slow
                # process of constructing and joining data sets from singletons. Not
                # optimizing this for now, as it's very unlikely in scenarios
                # we're targeting.
                #
                # Note: `ds` will not remain `None` because `elements` is not empty.
                ds = None
                ds = typing.cast(tf.data.Dataset, ds)
                for i in range(len(elements)):
                    singleton_ds = _make(elements[i:i + 1])
                    ds = singleton_ds if ds is None else ds.concatenate(
                        singleton_ds)
        ds_element_type = computation_types.to_type(ds.element_spec)
        if not element_type.is_assignable_from(ds_element_type):
            raise TypeError(
                'Failure during data set construction, expected elements of type {}, '
                'but the constructed data set has elements of type {}.'.format(
                    element_type, ds_element_type))
        return ds

    if graph is not None:
        with graph.as_default():
            return _work()
    else:
        return _work()
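
The failure mode that motivates the fallback logic above can be reproduced with plain tf.data (a sketch; the shapes are made up):

import numpy as np
import tensorflow as tf

equal = [np.ones([4, 2], np.float32)] * 3
ds = tf.data.Dataset.from_tensor_slices(equal)  # fine: three [4, 2] elements

ragged = equal + [np.ones([1, 2], np.float32)]  # odd-sized final batch
try:
    tf.data.Dataset.from_tensor_slices(ragged)  # can't stack unequal shapes
except ValueError:
    # This is the point at which _work() above splits off the last element
    # (or, failing that, falls back to singleton datasets).
    pass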
 async def create_value(self, value, type_spec=None):
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, intrinsic_defs.IntrinsicDef):
         if not type_utils.is_concrete_instance_of(type_spec,
                                                   value.type_signature):
             raise TypeError(
                 'Incompatible type {} used with intrinsic {}.'.format(
                     type_spec, value.uri))
         else:
             return FederatingExecutorValue(value, type_spec)
     if isinstance(value, placement_literals.PlacementLiteral):
         if type_spec is not None:
             py_typecheck.check_type(type_spec,
                                     computation_types.PlacementType)
         return FederatingExecutorValue(value,
                                        computation_types.PlacementType())
     elif isinstance(value, computation_impl.ComputationImpl):
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     elif isinstance(value, pb.Computation):
         if type_spec is None:
             type_spec = type_serialization.deserialize_type(value.type)
         which_computation = value.WhichOneof('computation')
         if which_computation in ['tensorflow', 'lambda']:
             return FederatingExecutorValue(value, type_spec)
         elif which_computation == 'reference':
             raise ValueError(
                 'Encountered an unexpected unbound reference "{}".'.
                 format(value.reference.name))
         elif which_computation == 'intrinsic':
             intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
             if intr is None:
                 raise ValueError(
                     'Encountered an unrecognized intrinsic "{}".'.format(
                         value.intrinsic.uri))
             py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
             return await self.create_value(intr, type_spec)
         elif which_computation == 'placement':
             return await self.create_value(
                 placement_literals.uri_to_placement_literal(
                     value.placement.uri), type_spec)
         elif which_computation == 'call':
             parts = [value.call.function]
             if value.call.argument.WhichOneof('computation'):
                 parts.append(value.call.argument)
             parts = await asyncio.gather(
                 *[self.create_value(x) for x in parts])
             return await self.create_call(
                 parts[0], parts[1] if len(parts) > 1 else None)
         elif which_computation == 'tuple':
             element_values = await asyncio.gather(
                 *[self.create_value(x.value) for x in value.tuple.element])
             return await self.create_tuple(
                 anonymous_tuple.AnonymousTuple(
                     (e.name if e.name else None, v)
                     for e, v in zip(value.tuple.element, element_values)))
         elif which_computation == 'selection':
             which_selection = value.selection.WhichOneof('selection')
             if which_selection == 'name':
                 name = value.selection.name
                 index = None
             elif which_selection != 'index':
                 raise ValueError(
                     'Unrecognized selection type: "{}".'.format(
                         which_selection))
             else:
                 index = value.selection.index
                 name = None
             return await self.create_selection(await self.create_value(
                 value.selection.source),
                                                index=index,
                                                name=name)
         else:
             raise ValueError(
                 'Unsupported computation building block of type "{}".'.
                 format(which_computation))
     else:
         py_typecheck.check_type(type_spec, computation_types.Type)
         if isinstance(type_spec, computation_types.FunctionType):
             raise ValueError(
                 'Encountered a value of a functional TFF type {} and Python type '
                 '{} that is not of one of the recognized representations.'.
                 format(type_spec, py_typecheck.type_string(type(value))))
         elif isinstance(type_spec, computation_types.FederatedType):
             children = self._target_executors.get(type_spec.placement)
             if not children:
                 raise ValueError(
                     'Placement "{}" is not configured in this executor.'.
                     format(type_spec.placement))
             py_typecheck.check_type(children, list)
             if not type_spec.all_equal:
                 py_typecheck.check_type(value,
                                         (list, tuple, set, frozenset))
                 if not isinstance(value, list):
                     value = list(value)
             elif isinstance(value, list):
                 raise ValueError(
                     'An all_equal value should be passed directly, not as a list.'
                 )
             else:
                 value = [value for _ in children]
             if len(value) != len(children):
                 raise ValueError(
                     'Federated value contains {} items, but the placement {} in this '
                     'executor is configured with {} participants.'.format(
                         len(value), type_spec.placement, len(children)))
             child_vals = await asyncio.gather(*[
                 c.create_value(v, type_spec.member)
                 for v, c in zip(value, children)
             ])
             return FederatingExecutorValue(child_vals, type_spec)
         else:
             child = self._target_executors.get(None)
             if not child or len(child) > 1:
                 raise RuntimeError(
                     'Executor is not configured for unplaced values.')
             else:
                 return FederatingExecutorValue(
                     await child[0].create_value(value, type_spec),
                     type_spec)
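
The all_equal branch near the end is worth restating in plain Python (executor objects replaced by placeholder strings; illustrative only):

children = ['client_0', 'client_1', 'client_2']  # stand-ins for child executors
value, all_equal = 42, True

if all_equal:
    # A single logical value is fanned out to every child executor.
    per_child = [value for _ in children]
else:
    # Otherwise the caller must already supply one item per child.
    per_child = list(value)

assert len(per_child) == len(children)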
Example #21
0
def local_executor_factory(
    num_clients=None,
    max_fanout=100,
    num_client_executors=32,
    server_tf_device=None,
    client_tf_devices=tuple()
) -> executor_factory.ExecutorFactory:
    """Constructs an executor factory to execute computations locally.

  Note: The `tff.federated_secure_sum()` intrinsic is not implemented by this
  executor.

  Args:
    num_clients: The number of clients. If specified, the executor factory
      function returned by `local_executor_factory` will be configured to have
      exactly `num_clients` clients. If unspecified (`None`), then the function
      returned will attempt to infer cardinalities of all placements for which
      it is passed values.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy. If
      `num_clients > max_fanout`, the constructed executor stack will consist of
      multiple levels of aggregators. The height of the stack will be on the
      order of `log(num_clients) / log(max_fanout)`.
    num_client_executors: The number of distinct client executors to run
      concurrently; executing more clients than this number results in multiple
      clients having their work pinned on a single executor in a synchronous
      fashion.
    server_tf_device: A `tf.config.LogicalDevice` on which to place the server
      and other computations without an explicit TFF placement.
    client_tf_devices: A list or tuple of `tf.config.LogicalDevice`s on which to
      place clients for simulation; possibly accelerators returned by
      `tf.config.list_logical_devices()`.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.

  Raises:
    ValueError: If the number of clients is specified and not one or larger.
  """
    if server_tf_device is not None:
        py_typecheck.check_type(server_tf_device, tf.config.LogicalDevice)
    py_typecheck.check_type(client_tf_devices, (tuple, list))
    py_typecheck.check_type(max_fanout, int)
    py_typecheck.check_type(num_client_executors, int)
    if num_clients is not None:
        py_typecheck.check_type(num_clients, int)
    if max_fanout < 2:
        raise ValueError('Max fanout must be greater than 1.')
    unplaced_ex_factory = UnplacedExecutorFactory(
        use_caching=True,
        server_device=server_tf_device,
        client_devices=client_tf_devices)
    federating_executor_factory = FederatingExecutorFactory(
        num_client_executors=num_client_executors,
        unplaced_ex_factory=unplaced_ex_factory,
        num_clients=num_clients,
        use_sizing=False)

    def _factory_fn(
        cardinalities: executor_factory.CardinalitiesType
    ) -> executor_base.Executor:
        return _create_full_stack(
            cardinalities,
            max_fanout,
            stack_func=federating_executor_factory.create_executor,
            unplaced_ex_factory=unplaced_ex_factory)

    return executor_factory.ExecutorFactoryImpl(_factory_fn)
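
To make the `log(num_clients) / log(max_fanout)` claim in the docstring concrete, here is a back-of-the-envelope helper (not part of TFF):

import math

def stack_height(num_clients: int, max_fanout: int) -> int:
    """Number of aggregation levels needed to fan num_clients into one root."""
    levels = 0
    while num_clients > 1:
        num_clients = math.ceil(num_clients / max_fanout)
        levels += 1
    return levels

assert stack_height(10_000, 100) == 2   # 10_000 -> 100 -> 1
assert stack_height(10_001, 100) == 3   # one straggler adds a level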
def _check_arg_is_anonymous_tuple(self, arg):
    py_typecheck.check_type(arg.type_signature,
                            computation_types.NamedTupleType)
    py_typecheck.check_type(arg.internal_representation,
                            anonymous_tuple.AnonymousTuple)
Example #23
0
def _make_wrapper(clipping_norm: Union[float,
                                       estimation_process.EstimationProcess],
                  inner_agg_factory: factory.AggregationFactory,
                  make_clip_fn: Callable[[factory.ValueType],
                                         computation_base.Computation],
                  attribute_prefix: str) -> factory.AggregationFactory:
  """Constructs an aggregation factory that applies clip_fn before aggregation.

  Args:
    clipping_norm: Either a float (for fixed norm) or an `EstimationProcess`
      (for adaptive norm) that specifies the norm over which the values should
      be clipped.
    inner_agg_factory: A factory specifying the type of aggregation to be done
      after clipping.
    make_clip_fn: A callable that takes a value type and returns a
      tff.computation specifying the clip operation to apply before aggregation.
    attribute_prefix: A str for prefixing state and measurement names.

  Returns:
    An aggregation factory that applies clip_fn before aggregation.
  """
  py_typecheck.check_type(inner_agg_factory,
                          (factory.UnweightedAggregationFactory,
                           factory.WeightedAggregationFactory))
  py_typecheck.check_type(clipping_norm,
                          (float, estimation_process.EstimationProcess))
  if isinstance(clipping_norm, float):
    clipping_norm_process = _constant_process(clipping_norm)
  else:
    clipping_norm_process = clipping_norm
  _check_norm_process(clipping_norm_process, 'clipping_norm_process')

  # The aggregation factory that will be used to count the number of clipped
  # values at each iteration. For now we are just creating it here, but in
  # the future we may make this customizable to allow DP measurements.
  clipped_count_agg_factory = sum_factory.SumFactory()

  clipped_count_agg_process = clipped_count_agg_factory.create(
      computation_types.to_type(COUNT_TF_TYPE))

  prefix = lambda s: attribute_prefix + s

  def init_fn_impl(inner_agg_process):
    state = collections.OrderedDict([
        (prefix('ing_norm'), clipping_norm_process.initialize()),
        ('inner_agg', inner_agg_process.initialize()),
        (prefix('ed_count_agg'), clipped_count_agg_process.initialize())
    ])
    return intrinsics.federated_zip(state)

  def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
    clipping_norm_state, agg_state, clipped_count_state = state

    clipping_norm = clipping_norm_process.report(clipping_norm_state)

    clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm)

    # TODO(b/163880757): Remove this when server-only metrics are supported.
    clipping_norm = intrinsics.federated_mean(clients_clipping_norm)

    clipped_value, global_norm, was_clipped = intrinsics.federated_map(
        clip_fn, (value, clients_clipping_norm))

    new_clipping_norm_state = clipping_norm_process.next(
        clipping_norm_state, global_norm)

    if weight is None:
      agg_output = inner_agg_process.next(agg_state, clipped_value)
    else:
      agg_output = inner_agg_process.next(agg_state, clipped_value, weight)

    clipped_count_output = clipped_count_agg_process.next(
        clipped_count_state, was_clipped)

    new_state = collections.OrderedDict([
        (prefix('ing_norm'), new_clipping_norm_state),
        ('inner_agg', agg_output.state),
        (prefix('ed_count_agg'), clipped_count_output.state)
    ])
    measurements = collections.OrderedDict([
        (prefix('ing'), agg_output.measurements),
        (prefix('ing_norm'), clipping_norm),
        (prefix('ed_count'), clipped_count_output.result)
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=agg_output.result,
        measurements=intrinsics.federated_zip(measurements))

  if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class WeightedRobustFactory(factory.WeightedAggregationFactory):
      """`WeightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType, weight_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type, weight_type)
        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(weight_type))
        def next_fn(state, value, weight):
          return next_fn_impl(state, value, clip_fn, inner_agg_process, weight)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return WeightedRobustFactory()
  else:

    class UnweightedRobustFactory(factory.UnweightedAggregationFactory):
      """`UnweightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)

        inner_agg_process = inner_agg_factory.create(value_type)
        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          return next_fn_impl(state, value, clip_fn, inner_agg_process)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedRobustFactory()
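
`next_fn_impl` above unpacks `(clipped_value, global_norm, was_clipped)` from `clip_fn`; a plausible TF body for such a function, using a global-norm clip, might look like the following (illustrative only, not the actual `make_clip_fn`):

import tensorflow as tf

def clip_by_norm(value, clipping_norm):
    # Clip the flattened structure to the given global norm.
    flat = tf.nest.flatten(value)
    clipped, global_norm = tf.clip_by_global_norm(flat, clipping_norm)
    was_clipped = tf.cast(global_norm > clipping_norm, tf.int32)
    return tf.nest.pack_sequence_as(value, clipped), global_norm, was_clipped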
Example #24
0
async def compute(self):
    # TODO(b/153499219): Add support for values of other types than tensors.
    py_typecheck.check_type(self._type_signature,
                            computation_types.TensorType)
    return self._value
Example #25
0
    def __init__(self,
                 upper_bound_threshold: ThresholdEstType,
                 lower_bound_threshold: Optional[ThresholdEstType] = None):
        """Initializes `SecureSumFactory`.

    Args:
      upper_bound_threshold: Either an `int` or `float` Python constant, a NumPy
        scalar, or a `tff.templates.EstimationProcess`, used for determining the
        upper bound before summation.
      lower_bound_threshold: Optional. Either an `int` or `float` Python
        constant, a NumPy scalar, or a `tff.templates.EstimationProcess`, used
        for determining the lower bound before summation. If specified, must be
        the same type as `upper_bound_threshold`.

    Raises:
      TypeError: If `upper_bound_threshold` and `lower_bound_threshold` are not
        instances of one of (`int`, `float` or
        `tff.templates.EstimationProcess`).
      ValueError: If `upper_bound_threshold` is provided as a negative constant.
    """
        py_typecheck.check_type(upper_bound_threshold,
                                ThresholdEstType.__args__)
        if lower_bound_threshold is not None:
            if not isinstance(lower_bound_threshold,
                              type(upper_bound_threshold)):
                raise TypeError(
                    f'Provided upper_bound_threshold and lower_bound_threshold '
                    f'must have the same types, but found:\n'
                    f'type(upper_bound_threshold): {type(upper_bound_threshold)}\n'
                    f'type(lower_bound_threshold): {type(lower_bound_threshold)}')

        # Configuration specific for aggregating integer types.
        if _is_integer(upper_bound_threshold):
            self._config_mode = _Config.INT
            if lower_bound_threshold is None:
                _check_positive(upper_bound_threshold)
                lower_bound_threshold = -1 * upper_bound_threshold
            else:
                _check_upper_larger_than_lower(upper_bound_threshold,
                                               lower_bound_threshold)
            self._init_fn = _empty_state
            self._get_bounds_from_state = _create_get_bounds_const(
                upper_bound_threshold, lower_bound_threshold)
            self._update_state = lambda _, __, ___: _empty_state()
            self._secagg_bitwidth = math.ceil(
                math.log2(upper_bound_threshold - lower_bound_threshold))

        # Configuration specific for aggregating floating point types.
        else:
            self._config_mode = _Config.FLOAT
            if _is_float(upper_bound_threshold):
                # Bounds specified as Python constants.
                if lower_bound_threshold is None:
                    _check_positive(upper_bound_threshold)
                    lower_bound_threshold = -1.0 * upper_bound_threshold
                else:
                    _check_upper_larger_than_lower(upper_bound_threshold,
                                                   lower_bound_threshold)
                self._get_bounds_from_state = _create_get_bounds_const(
                    upper_bound_threshold, lower_bound_threshold)
                self._init_fn = _empty_state
                self._update_state = lambda _, __, ___: _empty_state()
            else:
                # Bounds specified as an EstimationProcess.
                _check_bound_process(upper_bound_threshold,
                                     'upper_bound_threshold')
                if lower_bound_threshold is None:
                    self._get_bounds_from_state = _create_get_bounds_single_process(
                        upper_bound_threshold)
                    self._init_fn = upper_bound_threshold.initialize
                    self._update_state = _create_update_state_single_process(
                        upper_bound_threshold)
                else:
                    _check_bound_process(lower_bound_threshold,
                                         'lower_bound_threshold')
                    self._get_bounds_from_state = _create_get_bounds_two_processes(
                        upper_bound_threshold, lower_bound_threshold)
                    self._init_fn = _create_initial_state_two_processes(
                        upper_bound_threshold, lower_bound_threshold)
                    self._update_state = _create_update_state_two_processes(
                        upper_bound_threshold, lower_bound_threshold)
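
The integer branch's bitwidth computation is simple enough to sanity-check by hand: for symmetric bounds [-B, B] the range size is 2 * B, so:

import math

upper, lower = 100, -100
assert math.ceil(math.log2(upper - lower)) == 8   # 200 values fit in 8 bits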
def coerce_dataset_elements_to_tff_type_spec(dataset, element_type):
    """Map the elements of a dataset to a specified type.

  This is used to coerce a `tf.data.Dataset` that may have lost the ordering
  of dictionary keys back into a `collections.OrderedDict` (required by TFF).

  Args:
    dataset: a `tf.data.Dataset` instance.
    element_type: a `tff.Type` specifying the type of the elements of `dataset`.
      Must be a `tff.TensorType` or `tff.StructType`.

  Returns:
    A `tf.data.Dataset` whose output types are compatible with
    `element_type`.

  Raises:
    ValueError: if the elements of `dataset` cannot be coerced into
      `element_type`.
  """
    py_typecheck.check_type(dataset,
                            type_conversions.TF_DATASET_REPRESENTATION_TYPES)
    py_typecheck.check_type(element_type, computation_types.Type)

    if element_type.is_tensor():
        return dataset

    # This is similar to `reference_executor.to_representation_for_type`;
    # look for opportunities to consolidate?
    def _to_representative_value(type_spec, elements):
        """Convert to a container to a type understood by TF and TFF."""
        if type_spec.is_tensor():
            return elements
        elif type_spec.is_struct():
            field_types = structure.to_elements(type_spec)
            is_all_named = all([name is not None for name, _ in field_types])
            if is_all_named:
                if py_typecheck.is_named_tuple(elements):
                    values = collections.OrderedDict(
                        (name, _to_representative_value(field_type, e))
                        for (name,
                             field_type), e in zip(field_types, elements))
                    return type(elements)(**values)
                else:
                    values = [
                        (name,
                         _to_representative_value(field_type, elements[name]))
                        for name, field_type in field_types
                    ]
                    return collections.OrderedDict(values)
            else:
                return tuple(
                    _to_representative_value(t, e)
                    for t, e in zip(type_spec, elements))
        else:
            raise ValueError(
                'Coercing a dataset with elements of expected type {!s}, '
                'produced a value with incompatible type `{!s}`. Value: '
                '{!s}'.format(type_spec, type(elements), elements))

    # A tf.data.Dataset of tuples will unwrap the tuple in the `map()` call, so
    # we must pass a function taking *args. However, if the element was
    # originally a single tuple, it is now "double wrapped" and must be
    # unwrapped before traversing.
    def _unwrap_args(*args):
        if len(args) == 1:
            return _to_representative_value(element_type, args[0])
        else:
            return _to_representative_value(element_type, args)

    return dataset.map(_unwrap_args)
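
The double-wrapping caveat in the comment above is easy to see with a toy dataset: tf.data unpacks tuple elements into positional arguments of the mapped function, hence the *args wrapper.

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices((tf.range(3), tf.range(3, 6)))
summed = ds.map(lambda a, b: a + b)   # the tuple element arrives unpacked
print([x.numpy() for x in summed])    # [3, 5, 7]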
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
  # it deals exclusively with eager mode. Incubate here, and potentially move
  # there, once stable.

  if device is not None:
    raise NotImplementedError('Unable to embed TF code on a specific device.')

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          str(type_spec), str(comp_type)))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = graph_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          str(len(input_tensor_names)), str(len(args))))
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      graph_def = graph_utils.add_control_deps_for_init_op(graph_def, init_op)
    return tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def),
        input_map=dict(zip(input_tensor_names, args)),
        return_elements=output_tensor_names)

  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            str(len(param_fns)), str(len(arg_parts))))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)
    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)
  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
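
The to_variant/from_variant pair used above for SequenceType parameters and results round-trips a dataset through a single variant tensor; a minimal eager-mode sketch:

import tensorflow as tf

ds = tf.data.Dataset.range(3)
variant = tf.data.experimental.to_variant(ds)
restored = tf.data.experimental.from_variant(variant, ds.element_spec)
assert [x.numpy() for x in restored] == [0, 1, 2]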
def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
    """Stamps a parameter of a given type in the given tf.Graph instance.

  Tensors are stamped as placeholders, sequences are stamped as data sets
  constructed from string tensor handles, and named tuples are stamped by
  independently stamping their elements.

  Args:
    parameter_name: The suggested (string) name of the parameter to use in
      determining the names of the graph components to construct. The names that
      will actually appear in the graph are not guaranteed to be based on this
      suggested name, and may vary, e.g., due to existing naming conflicts, but
      a best-effort attempt will be made to make them similar for ease of
      debugging.
    parameter_type: The type of the parameter to stamp. Must be either an
      instance of computation_types.Type (or convertible to it), or None.
    graph: The instance of tf.Graph to stamp in.

  Returns:
    A tuple (val, binding), where 'val' is a Python object (such as a dataset,
    a placeholder, or a `structure.Struct` that represents a named
    tuple) that represents the stamped parameter for use in the body of a Python
    function that consumes this parameter, and the 'binding' is an instance of
    TensorFlow.Binding that indicates how parts of the type signature relate
    to the tensors and ops stamped into the graph.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the parameter type cannot be stamped in a TensorFlow graph.
  """
    py_typecheck.check_type(parameter_name, str)
    py_typecheck.check_type(graph, tf.Graph)
    if parameter_type is None:
        return (None, None)
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type.is_tensor():
        with graph.as_default():
            placeholder = tf.compat.v1.placeholder(dtype=parameter_type.dtype,
                                                   shape=parameter_type.shape,
                                                   name=parameter_name)
            binding = pb.TensorFlow.Binding(tensor=pb.TensorFlow.TensorBinding(
                tensor_name=placeholder.name))
            return (placeholder, binding)
    elif parameter_type.is_struct():
        # The parameter_type could be a StructTypeWithPyContainer; however, we
        # ignore that for now. Instead, the proper containers will be inserted at
        # call time by function_utils.wrap_as_zero_or_one_arg_callable.
        if not parameter_type:
            # Stamp a dummy element to "populate" the graph, as TensorFlow does
            # not support empty graphs.
            dummy_tensor = tf.no_op()
        element_name_value_pairs = []
        element_bindings = []
        for e in structure.iter_elements(parameter_type):
            e_val, e_binding = stamp_parameter_in_graph(
                '{}_{}'.format(parameter_name, e[0]), e[1], graph)
            element_name_value_pairs.append((e[0], e_val))
            element_bindings.append(e_binding)
        return (structure.Struct(element_name_value_pairs),
                pb.TensorFlow.Binding(struct=pb.TensorFlow.StructBinding(
                    element=element_bindings)))
    elif parameter_type.is_sequence():
        with graph.as_default():
            variant_tensor = tf.compat.v1.placeholder(tf.variant, shape=[])
            ds = make_dataset_from_variant_tensor(variant_tensor,
                                                  parameter_type.element)
        return (ds,
                pb.TensorFlow.Binding(sequence=pb.TensorFlow.SequenceBinding(
                    variant_tensor_name=variant_tensor.name)))
    else:
        raise ValueError(
            'Parameter type component {!r} cannot be stamped into a TensorFlow '
            'graph.'.format(parameter_type))
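
The tensor branch above reduces to stamping a named placeholder and recording its graph-level name (a sketch without the proto binding; the parameter name is made up):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    placeholder = tf.compat.v1.placeholder(
        tf.float32, shape=[None, 2], name='my_param')
# The ':0' suffix denotes the op's first output tensor; this is the name a
# pb.TensorFlow.TensorBinding would record.
assert placeholder.name == 'my_param:0'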
def _check_arg_is_structure(self, arg):
  py_typecheck.check_type(arg.type_signature, computation_types.StructType)
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
Example #30
0
def pack_args_into_anonymous_tuple(args, kwargs, type_spec=None, context=None):
    """Packs positional and keyword arguments into an anonymous tuple.

  If 'type_spec' is not None, it must be a tuple type or something that's
  convertible to it by computation_types.to_type(). The assignment of arguments
  to fields of the tuple follows the same rule as during function calls. If
  'type_spec' is None, the positional arguments precede any of the keyword
  arguments, and the ordering of the keyword arguments matches the ordering in
  which they appear in kwargs. If the latter is an OrderedDict, the ordering
  will be preserved. On the other hand, if the latter is an ordinary unordered
  dict, the ordering is arbitrary.

  Args:
    args: Positional arguments.
    kwargs: Keyword arguments.
    type_spec: The optional type specification (either an instance of
      computation_types.NamedTupleType or something convertible to it), or None
      if there's no type. Used to drive the arrangements of args into fields of
      the constructed anonymous tuple, as noted in the description.
    context: The optional context (an instance of `context_base.Context`) in
      which the arguments are being packed. Required if and only if the
      `type_spec` is not `None`.

  Returns:
    An anonymous tuple containing all the arguments.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
    type_spec = computation_types.to_type(type_spec)
    if not type_spec:
        return anonymous_tuple.AnonymousTuple([(None, arg) for arg in args] +
                                              list(six.iteritems(kwargs)))
    else:
        py_typecheck.check_type(type_spec, computation_types.NamedTupleType)
        py_typecheck.check_type(context, context_base.Context)
        if not is_argument_tuple(type_spec):
            raise TypeError(
                'Parameter type {} does not have a structure of an argument tuple, '
                'and cannot be populated from multiple positional and keyword '
                'arguments'.format(type_spec))
        else:
            result_elements = []
            positions_used = set()
            keywords_used = set()
            for index, (name, elem_type) in enumerate(
                    anonymous_tuple.to_elements(type_spec)):
                if index < len(args):
                    if name is not None and name in kwargs:
                        raise TypeError(
                            'Argument {} specified twice.'.format(name))
                    else:
                        arg_value = args[index]
                        result_elements.append(
                            (name, context.ingest(arg_value, elem_type)))
                        positions_used.add(index)
                elif name is not None and name in kwargs:
                    arg_value = kwargs[name]
                    result_elements.append(
                        (name, context.ingest(arg_value, elem_type)))
                    keywords_used.add(name)
                elif name:
                    raise TypeError(
                        'Argument named {} is missing.'.format(name))
                else:
                    raise TypeError(
                        'Argument at position {} is missing.'.format(index))
            positions_missing = set(range(
                len(args))).difference(positions_used)
            if positions_missing:
                raise TypeError('Positional arguments at {} not used.'.format(
                    positions_missing))
            keywords_missing = set(kwargs.keys()).difference(keywords_used)
            if keywords_missing:
                raise TypeError('Keyword arguments at {} not used.'.format(
                    keywords_missing))
            return anonymous_tuple.AnonymousTuple(result_elements)
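
The assignment rule above, stripped of type checking and ingestion, fits in a few lines of plain Python (a simplified model, not the TFF implementation):

def pack(field_names, args, kwargs):
    """Assigns positional then keyword args to fields, mirroring a call."""
    result = []
    for index, name in enumerate(field_names):
        if index < len(args):
            if name is not None and name in kwargs:
                raise TypeError('Argument {} specified twice.'.format(name))
            result.append((name, args[index]))
        elif name is not None and name in kwargs:
            result.append((name, kwargs[name]))
        else:
            raise TypeError('Missing argument for field {}.'.format(index))
    return result

assert pack(['x', None, 'z'], (1, 2), {'z': 3}) == [('x', 1), (None, 2), ('z', 3)]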