예제 #1
0
 def __init__(self, test, uri, value, type_spec):
     """Records the constructor arguments on the instance.

     Args:
       test: Stored unchanged as `self._test`.
       uri: Stored unchanged as `self._uri`.
       value: Stored unchanged as `self._value`.
       type_spec: Anything accepted by `computation_types.to_type`; the
         converted type is stored as `self._type_spec`.
     """
     self._test, self._uri, self._value = test, uri, value
     self._type_spec = computation_types.to_type(type_spec)
예제 #2
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Creates an `AggregationProcess` that discretizes values before aggregating.

    Client values are converted with `_discretize_struct` using a step size
    held at the server and broadcast to clients. They are aggregated by a
    process built from `self._inner_agg_factory`, and the aggregate is
    converted back on the server with `_undiscretize_struct` using the same
    step size.

    Args:
      value_type: The TFF type of the values to aggregate. Must be a
        `TensorType`, or a `StructWithPythonType` containing only tensors, and
        every component dtype must be a float.

    Returns:
      An `aggregation_process.AggregationProcess`. Its state is a federated
      `OrderedDict` with a `step_size` entry (initialized from
      `self._step_size` at SERVER) and an `inner_agg_process` entry. Its
      measurements contain `deterministic_discretization` (the inner process's
      measurements). They also contain `distortion` when
      `self._distortion_aggregation_factory` is not None.

    Raises:
      TypeError: If `value_type` is not a tensor or a Python-container struct
        of tensors, or if any component dtype is not a float.
    """
    # Validate input args and value_type and parse out the TF dtypes.
    if value_type.is_tensor():
      tf_dtype = value_type.dtype
    elif (value_type.is_struct_with_python() and
          type_analysis.is_structure_of_tensors(value_type)):
      tf_dtype = type_conversions.structure_from_tensor_type_tree(
          lambda x: x.dtype, value_type)
    else:
      raise TypeError('Expected `value_type` to be `TensorType` or '
                      '`StructWithPythonType` containing only `TensorType`. '
                      f'Found type: {repr(value_type)}')

    # Check that all values are floats.
    if not type_analysis.is_structure_of_floats(value_type):
      raise TypeError('Component dtypes of `value_type` must all be floats. '
                      f'Found {repr(value_type)}.')

    # The optional distortion process aggregates one scalar float32 per client
    # (the value returned by `distortion_measurement_fn` below).
    if self._distortion_aggregation_factory is not None:
      distortion_aggregation_process = self._distortion_aggregation_factory.create(
          computation_types.to_type(tf.float32))

    # TF computation wrapping `_discretize_struct` over (value, step_size).
    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def discretize_fn(value, step_size):
      return _discretize_struct(value, step_size)

    # Accepts whatever type `discretize_fn` produces. `tf_dtype` is passed so
    # the helper can presumably restore the original component dtypes; TODO
    # confirm against `_undiscretize_struct`.
    @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                           tf.float32)
    def undiscretize_fn(value, step_size):
      return _undiscretize_struct(value, step_size, tf_dtype)

    # Round-trips `value` through discretize/undiscretize and returns the mean
    # squared reconstruction error over all elements of all components, as a
    # float32 scalar.
    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def distortion_measurement_fn(value, step_size):
      reconstructed_value = undiscretize_fn(
          discretize_fn(value, step_size), step_size)
      err = tf.nest.map_structure(tf.subtract, reconstructed_value, value)
      squared_err = tf.nest.map_structure(tf.square, err)
      # Flatten every component to 1-D float32 so they can be concatenated.
      flat_squared_errs = [
          tf.cast(tf.reshape(t, [-1]), tf.float32)
          for t in tf.nest.flatten(squared_err)
      ]
      all_squared_errs = tf.concat(flat_squared_errs, axis=0)
      mean_squared_err = tf.reduce_mean(all_squared_errs)
      return mean_squared_err

    # The inner process aggregates the discretized values.
    inner_agg_process = self._inner_agg_factory.create(
        discretize_fn.type_signature.result)

    @federated_computation.federated_computation()
    def init_fn():
      state = collections.OrderedDict(
          step_size=intrinsics.federated_value(self._step_size,
                                               placements.SERVER),
          inner_agg_process=inner_agg_process.initialize())
      return intrinsics.federated_zip(state)

    @federated_computation.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type))
    def next_fn(state, value):
      # Clients discretize with the broadcast copy of the same step size the
      # server later uses to undiscretize the aggregate.
      server_step_size = state['step_size']
      client_step_size = intrinsics.federated_broadcast(server_step_size)

      discretized_value = intrinsics.federated_map(discretize_fn,
                                                   (value, client_step_size))

      inner_state = state['inner_agg_process']
      inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

      undiscretized_agg_value = intrinsics.federated_map(
          undiscretize_fn, (inner_agg_output.result, server_step_size))

      # The step size is carried forward unchanged.
      new_state = collections.OrderedDict(
          step_size=server_step_size, inner_agg_process=inner_agg_output.state)
      measurements = collections.OrderedDict(
          deterministic_discretization=inner_agg_output.measurements)

      if self._distortion_aggregation_factory is not None:
        distortions = intrinsics.federated_map(distortion_measurement_fn,
                                               (value, client_step_size))
        # NOTE(review): the distortion process is freshly initialized on every
        # call and only `.result` is kept, so it carries no state across
        # rounds. Presumably intentional for a pure measurement; confirm.
        aggregate_distortion = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), distortions).result
        measurements['distortion'] = aggregate_distortion

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=undiscretized_agg_value,
          measurements=intrinsics.federated_zip(measurements))

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #3
0
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Traces a Python function containing JAX code into a TFF XLA computation.

    Args:
      traced_fn: The Python callable containing JAX code. It is traced with
        `jax.xla_computation` and the resulting XLA code is serialized.
      arg_fn: An unpacking callable that takes the packed TFF argument (or
        `None`) and returns an `(args, kwargs)` pair used to invoke
        `traced_fn`, e.g. one built by
        `function_utils.create_argument_unpacking_fn`.
      parameter_type: A `computation_types.Type` (or anything accepted by
        `computation_types.to_type`) describing the computation parameter, or
        `None` if the function takes no parameters.
      context_stack: The context stack in which a `JaxComputationContext` is
        installed while tracing.

    Returns:
      An instance of `pb.Computation` with the constructed computation.

    Raises:
      TypeError: if the arguments are of the wrong types.
    """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    packed_arg = None
    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)

    call_args, call_kwargs = arg_fn(packed_arg)

    # The fake parameters are fed in through args/kwargs, but the generated
    # XLA code may order them differently. Flattening with the same
    # tree-flattening JAX uses lets the parameter binding record that order.
    # Results need no such treatment: multiple results always come back as a
    # tuple.
    leaves = jax.tree_util.tree_flatten((call_args, call_kwargs))[0]
    tensor_indexes = list(np.argsort([leaf.tensor_index for leaf in leaves]))

    with context_stack.install(
            jax_computation_context.JaxComputationContext()):
        xla_builder = jax.xla_computation(traced_fn,
                                          tuple_args=True,
                                          return_shape=True)
        compiled_xla, returned_shape = xla_builder(*call_args, **call_kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        result_type = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
    else:
        result_struct = structure.from_container(returned_shape,
                                                 recursive=True)
        result_type = computation_types.to_type(
            structure.map_structure(_jax_shape_dtype_struct_to_tff_tensor,
                                    result_struct))

    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes,
        computation_types.FunctionType(parameter_type, result_type))
예제 #4
0
def save(model: model_lib.Model, path: str, input_type=None) -> None:
    """Serializes `model` as a TensorFlow SavedModel to `path`.

    The resulting SavedModel will contain the default serving signature, which
    can be used with the TFLite converter to create a TFLite flatbuffer for
    inference.

    NOTE: The model returned by `tff.learning.models.load` will _not_ be the
    same Python type as the saved model. If the model serialized using this
    method is a subclass of `tff.learning.Model`, that subclass is _not_
    returned. All method behavior is retained, but the Python type does not
    cross serialization boundaries. The return type of `metric_finalizers` will
    be an OrderedDict of str to `tff.tf_computation` (annotated TFF
    computations) which could be different from that of the model before
    serialization.

    Args:
      model: The `tff.learning.Model` to save.
      path: The `str` directory path to serialize the model to.
      input_type: An optional structure of `tf.TensorSpec`s representing the
        expected input of `model.predict_on_batch`, to override reading from
        `model.input_spec`. Typically this will be similar to
        `model.input_spec`, with any example labels removed. If None, default
        to `model.input_spec['x']` if the input_spec is a mapping, otherwise
        default to `model.input_spec[0]`.

    Raises:
      TypeError: If `model` or `path` has the wrong type (checked with
        `py_typecheck.check_type`).
      ValueError: If `path` is empty.
    """
    py_typecheck.check_type(model, model_lib.Model)
    py_typecheck.check_type(path, str)
    if not path:
        raise ValueError('`path` must be a non-empty string, cannot serialize '
                         'models without an output path.')
    if isinstance(model, _LoadedSavedModel):
        # If we're saving a previously loaded model, we can simply use the module
        # already internal to the Model.
        _save_tensorflow_module(model._loaded_module, path)  # pylint: disable=protected-access
        return

    def _proto_as_variable(proto):
        # Stores a deterministically serialized proto in a non-trainable string
        # `tf.Variable`, so it is written into the SavedModel with the module.
        return tf.Variable(proto.SerializeToString(deterministic=True),
                           trainable=False)

    m = tf.Module()
    # We prefixed with `tff_` because `trainable_variables` is an attribute
    # reserved by `tf.Module`.
    m.tff_trainable_variables = model.trainable_variables
    m.tff_non_trainable_variables = model.non_trainable_variables
    m.tff_local_variables = model.local_variables
    # Serialize forward_pass. We must get two concrete versions of the
    # function, as the `training` argument is a Python value that changes the
    # graph computation. We serialize the output type so that we can repack the
    # flattened values after loading the saved model.
    forward_pass_training = _make_concrete_flat_output_fn(
        functools.partial(model.forward_pass, training=True), model.input_spec)
    m.flat_forward_pass_training = forward_pass_training[0]
    m.forward_pass_training_type_spec = _proto_as_variable(
        forward_pass_training[1])

    forward_pass_inference = _make_concrete_flat_output_fn(
        functools.partial(model.forward_pass, training=False),
        model.input_spec)
    m.flat_forward_pass_inference = forward_pass_inference[0]
    m.forward_pass_inference_type_spec = _proto_as_variable(
        forward_pass_inference[1])

    # Get model prediction input type. If `None`, default to assuming the 'x' key
    # or first element of the model input spec is the input.
    if input_type is None:
        if isinstance(model.input_spec, collections.abc.Mapping):
            input_type = model.input_spec['x']
        else:
            input_type = model.input_spec[0]
    # Serialize predict_on_batch. We must get two concrete versions of the
    # function, as the `training` argument is a Python value that changes the
    # graph computation.
    predict_on_batch_training = _make_concrete_flat_output_fn(
        functools.partial(model.predict_on_batch, training=True), input_type)
    m.predict_on_batch_training = predict_on_batch_training[0]
    m.predict_on_batch_training_type_spec = _proto_as_variable(
        predict_on_batch_training[1])
    predict_on_batch_inference = _make_concrete_flat_output_fn(
        functools.partial(model.predict_on_batch, training=False), input_type)
    m.predict_on_batch_inference = predict_on_batch_inference[0]
    m.predict_on_batch_inference_type_spec = _proto_as_variable(
        predict_on_batch_inference[1])

    # Serialize the report_local_unfinalized_metrics tf.function.
    m.report_local_unfinalized_metrics = (
        model.report_local_unfinalized_metrics.get_concrete_function())

    # Serialize the metric_finalizers as `tf.Variable`s.
    m.serialized_metric_finalizers = collections.OrderedDict()

    def serialize_metric_finalizer(finalizer, metric_type):
        finalizer_computation = tensorflow_computation.tf_computation(
            finalizer, metric_type)
        return _proto_as_variable(
            computation_serialization.serialize_computation(
                finalizer_computation))

    metric_finalizers = model.metric_finalizers()
    if metric_finalizers:
        # The unfinalized metric values are only used for their types, so they
        # are reported once rather than once per metric.
        unfinalized_metrics = model.report_local_unfinalized_metrics()
        for metric_name, finalizer in metric_finalizers.items():
            metric_type = type_conversions.type_from_tensors(
                unfinalized_metrics[metric_name])
            m.serialized_metric_finalizers[metric_name] = (
                serialize_metric_finalizer(finalizer, metric_type))

    # Serialize the TFF values as string variables that contain the serialized
    # protos from the computation or the type.
    m.serialized_input_spec = _proto_as_variable(
        type_serialization.serialize_type(
            computation_types.to_type(model.input_spec)))

    # Serialize the reset_metrics tf.function. Models that do not implement it
    # are saved with `reset_metrics = None`.
    try:
        m.reset_metrics = model.reset_metrics.get_concrete_function()
    except NotImplementedError:
        m.reset_metrics = None

    _save_tensorflow_module(m, path)
예제 #5
0

def create_whimsy_intrinsic_def_federated_secure_sum_bitwidth():
    """Returns `FEDERATED_SECURE_SUM_BITWIDTH` paired with a sample signature.

    The signature takes a clients-placed int32 and an unplaced int32
    (presumably the bitwidth) and returns a server-placed int32.
    """
    parameter_types = [
        computation_types.at_clients(tf.int32),
        tf.int32,
    ]
    result_type = computation_types.at_server(tf.int32)
    signature = computation_types.FunctionType(parameter_types, result_type)
    return intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH, signature


# Sample types describing a FEDERATED_SECURE_SELECT intrinsic signature; used
# by `create_whimsy_intrinsic_def_federated_secure_select`.

# Clients-placed int32 vectors of length 3 (the client keys argument).
_WHIMSY_SELECT_CLIENT_KEYS_TYPE = computation_types.at_clients(
    computation_types.TensorType(tf.int32, [3]))
# Server-placed int32 (the max key argument).
_WHIMSY_SELECT_MAX_KEY_TYPE = computation_types.at_server(tf.int32)
# Server-placed string state; presumably forwarded to the select fn, whose
# first parameter is also a string.
_WHIMSY_SELECT_SERVER_STATE_TYPE = computation_types.at_server(tf.string)
# A (string, int32) struct: both the select fn's result type and the element
# type of the result sequences.
_WHIMSY_SELECTED_TYPE = computation_types.to_type((tf.string, tf.int32))
# Select fn: (string, int32) -> (string, int32).
_WHIMSY_SELECT_SELECT_FN_TYPE = computation_types.FunctionType(
    (tf.string, tf.int32), _WHIMSY_SELECTED_TYPE)
# Clients-placed sequences of selected elements.
_WHIMSY_SELECT_RESULT_TYPE = computation_types.at_clients(
    computation_types.SequenceType(_WHIMSY_SELECTED_TYPE))
# Full signature: (client keys, max key, server state, select fn) -> result.
_WHIMSY_SELECT_TYPE = computation_types.FunctionType([
    _WHIMSY_SELECT_CLIENT_KEYS_TYPE,
    _WHIMSY_SELECT_MAX_KEY_TYPE,
    _WHIMSY_SELECT_SERVER_STATE_TYPE,
    _WHIMSY_SELECT_SELECT_FN_TYPE,
], _WHIMSY_SELECT_RESULT_TYPE)
# NOTE(review): not referenced in this chunk; presumably the number of clients
# used by tests exercising these types.
_WHIMSY_SELECT_NUM_CLIENTS = 3


def create_whimsy_intrinsic_def_federated_secure_select():
    """Returns `FEDERATED_SECURE_SELECT` paired with `_WHIMSY_SELECT_TYPE`."""
    intrinsic = intrinsic_defs.FEDERATED_SECURE_SELECT
    signature = _WHIMSY_SELECT_TYPE
    return intrinsic, signature
예제 #6
0
 def test_structure_of_tensors(self):
     """Merges sample reservoirs whose samples are a nested tensor structure.

     Checks the merge computation's type signature. It then merges a fixed
     three-sample reservoir `a` (sample_size=5) with different `b` reservoirs
     to cover downsampling, keeping every sample, and tie-breaking.
     """
     example_type = computation_types.to_type(
         collections.OrderedDict(
             a=TensorType(tf.int32, [3]),
             b=[TensorType(tf.float32),
                TensorType(tf.bool)]))
     merge_computation = sampling._build_merge_samples_computation(
         example_type, sample_size=5)
     reservoir_type = sampling._build_reservoir_type(example_type)
     expected_type = FunctionType(parameter=collections.OrderedDict(
         a=reservoir_type, b=reservoir_type),
                                  result=reservoir_type)
     self.assert_types_identical(merge_computation.type_signature,
                                 expected_type)
     reservoir_a = sampling._build_initial_sample_reservoir(example_type,
                                                            seed=TEST_SEED)
     reservoir_a['random_values'] = [1, 3, 5]
     reservoir_a['samples'] = collections.OrderedDict(
         a=[[0, 1, 2], [1, 2, 3], [2, 3, 4]],
         b=[[0.0, 1.0, 2.0], [True, False, True]])
     with self.subTest('downsample'):
         # 3 + 4 = 7 candidates exceed sample_size=5, so the two samples with
         # the smallest random_values (1 from `a`, 2 from `b`) are dropped.
         reservoir_b = sampling._build_initial_sample_reservoir(
             example_type, seed=TEST_SEED + 1)
         reservoir_b['random_values'] = [2, 4, 6, 8]
         reservoir_b['samples'] = collections.OrderedDict(
             a=[[0, -1, -2], [-1, -2, -3], [-2, -3, -4], [-3, -4, -5]],
             b=[[-1., -2., -3., -4.], [True, False, False, True]])
         merged_reservoir = merge_computation(reservoir_a, reservoir_b)
         self.assertAllEqual(
             merged_reservoir,
             collections.OrderedDict(
                 # Arbitrarily take seeds from `a`, discarded later.
                 random_seed=tf.convert_to_tensor((TEST_SEED, TEST_SEED)),
                 random_values=[3, 5, 4, 6, 8],
                 samples=collections.OrderedDict(
                     a=[[1, 2, 3], [2, 3, 4], [-1, -2, -3], [-2, -3, -4],
                        [-3, -4, -5]],
                     b=[[1., 2., -2., -3., -4.],
                        [False, True, False, False, True]])))
     with self.subTest('keep_all'):
         reservoir_b = sampling._build_initial_sample_reservoir(
             example_type, seed=TEST_SEED + 1)
         reservoir_b['random_values'] = [2]
         reservoir_b['samples'] = collections.OrderedDict(a=[[0, -1, -2]],
                                                          b=[[-1.0],
                                                             [True]])
         # Only 3 + 1 = 4 samples in total, which fits within sample_size=5,
         # so every sample is kept: `a`'s samples followed by `b`'s.
         merged_reservoir = merge_computation(reservoir_a, reservoir_b)
         self.assertAllEqual(
             merged_reservoir,
             collections.OrderedDict(
                 # Arbitrarily take seeds from `a`, discarded later.
                 random_seed=tf.convert_to_tensor((TEST_SEED, TEST_SEED)),
                 random_values=[1, 3, 5, 2],
                 samples=collections.OrderedDict(
                     a=[[0, 1, 2], [1, 2, 3], [2, 3, 4], [-0, -1, -2]],
                     b=[[0., 1., 2., -1.], [True, False, True, True]])))
     with self.subTest('tie_breakers'):
         # On ties, samples from `a` are kept before samples from `b`: `a`'s
         # single value 5 survives, plus four of `b`'s five tied 5s.
         reservoir_b = sampling._build_initial_sample_reservoir(
             example_type, seed=TEST_SEED)
         reservoir_b['random_values'] = [5, 5, 5, 5, 5]  # all tied with `a`
         reservoir_b['samples'] = collections.OrderedDict(a=[[-1, -1, -1]] *
                                                          5,
                                                          b=[[-1] * 5,
                                                             [False] * 5])
         merged_reservoir = merge_computation(reservoir_a, reservoir_b)
         self.assertAllEqual(
             merged_reservoir,
             collections.OrderedDict(random_seed=tf.convert_to_tensor(
                 (TEST_SEED, TEST_SEED)),
                                     random_values=[5, 5, 5, 5, 5],
                                     samples=collections.OrderedDict(
                                         a=[[2, 3, 4]] + [[-1, -1, -1]] * 4,
                                         b=[[2] + [-1] * 4,
                                            [True] + [False] * 4])))
예제 #7
0
def to_value(
    arg: Any,
    type_spec=None,
    parameter_type_hint=None,
) -> Value:
    """Converts the argument into an instance of the abstract class `tff.Value`.

    Instances of `tff.Value` represent TFF values that appear internally in
    federated computations. This helper function can be used to wrap a variety
    of Python objects as `tff.Value` instances to allow them to be passed as
    arguments, used as functions, or otherwise manipulated within bodies of
    federated computations.

    At the moment, the supported types include:

    * Simple constants of `str`, `int`, `float`, and `bool` types, mapped to
      values of a TFF tensor type.

    * Numpy arrays (`np.ndarray` objects), also mapped to TFF tensors.

    * Dictionaries (`collections.OrderedDict` and unordered `dict`), `list`s,
      `tuple`s, `namedtuple`s, `attrs` classes, and `Struct`s, all of which
      are mapped to TFF tuple type. Unordered `dict`s are converted in sorted
      key order; `OrderedDict`s keep their insertion order.

    * Computations (constructed with either the `tff.tf_computation` or with
      the `tff.federated_computation` decorator), typically mapped to TFF
      functions. Polymorphic functions additionally require
      `parameter_type_hint`.

    * Placement literals (`tff.CLIENTS`, `tff.SERVER`), mapped to values of
      the TFF placement type.

    This function is also invoked when attempting to execute a TFF
    computation. All arguments supplied in the invocation are converted into
    TFF values prior to execution. The types of Python objects that can be
    passed as arguments to computations thus match the types listed here.

    Args:
      arg: An instance of one of the Python types that are convertible to TFF
        values (instances of `tff.Value`).
      type_spec: An optional type specifier (defaults to `None`) that allows
        for disambiguating the target type (e.g., when two TFF types can be
        mapped to the same Python representations). If not specified, TFF
        tries to determine the type of the TFF value automatically.
      parameter_type_hint: An optional `tff.Type` or value convertible to it
        by `tff.to_type()` which specifies an argument type to use in the case
        that `arg` is a `function_utils.PolymorphicFunction`.

    Returns:
      An instance of `tff.Value` as described above.

    Raises:
      TypeError: if `arg` is of an unsupported type, or of a type that does
        not match `type_spec`. Raises explicit error message if TensorFlow
        constructs are encountered, as TensorFlow code should be sealed away
        from TFF federated context.
    """
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    if isinstance(arg, Value):
        result = arg
    elif isinstance(arg, building_blocks.ComputationBuildingBlock):
        result = Value(arg)
    elif isinstance(arg, placements.PlacementLiteral):
        result = Value(building_blocks.Placement(arg))
    elif isinstance(arg, (computation_base.Computation,
                          function_utils.PolymorphicFunction)):
        if isinstance(arg, function_utils.PolymorphicFunction):
            if parameter_type_hint is None:
                raise TypeError(
                    'Polymorphic computations cannot be converted to `tff.Value`s '
                    'without a type hint. Consider explicitly specifying the '
                    'argument types of a computation before passing it to a '
                    'function that requires a `tff.Value` (such as a TFF intrinsic '
                    'like `federated_map`). If you are a TFF developer and think '
                    'this should be supported, consider providing `parameter_type_hint` '
                    'as an argument to the encompassing `to_value` conversion.'
                )
            # Specialize the polymorphic function to the hinted argument type.
            parameter_type_hint = computation_types.to_type(
                parameter_type_hint)
            arg = arg.fn_for_argument_type(parameter_type_hint)
        py_typecheck.check_type(arg, computation_base.Computation)
        result = Value(arg.to_compiled_building_block())
    elif type_spec is not None and type_spec.is_sequence():
        # Checked before the container branches so that e.g. a list can be
        # interpreted as a sequence when `type_spec` asks for one.
        result = _wrap_sequence_as_value(arg, type_spec.element)
    elif isinstance(arg, structure.Struct):
        items = structure.iter_elements(arg)
        result = _dictlike_items_to_value(items, type_spec, None)
    elif py_typecheck.is_named_tuple(arg):
        items = arg._asdict().items()
        result = _dictlike_items_to_value(items, type_spec, type(arg))
    elif py_typecheck.is_attrs(arg):
        items = attr.asdict(arg,
                            dict_factory=collections.OrderedDict,
                            recurse=False).items()
        result = _dictlike_items_to_value(items, type_spec, type(arg))
    elif isinstance(arg, dict):
        # Plain dicts are sorted by key so the element order of the resulting
        # struct is deterministic; OrderedDicts keep their insertion order.
        if isinstance(arg, collections.OrderedDict):
            items = arg.items()
        else:
            items = sorted(arg.items())
        result = _dictlike_items_to_value(items, type_spec, type(arg))
    elif isinstance(arg, (tuple, list)):
        # Positional containers have no element names.
        items = zip(itertools.repeat(None), arg)
        result = _dictlike_items_to_value(items, type_spec, type(arg))
    elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
        result = _wrap_constant_as_value(arg)
    elif isinstance(arg, (tf.Tensor, tf.Variable)):
        raise TypeError(
            'TensorFlow construct {} has been encountered in a federated '
            'context. TFF does not support mixing TF and federated orchestration '
            'code. Please wrap any TensorFlow constructs with '
            '`tff.tf_computation`.'.format(arg))
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a `tff.Value`.'.
            format(py_typecheck.type_string(type(arg))))
    py_typecheck.check_type(result, Value)
    if (type_spec is not None
            and not type_spec.is_assignable_from(result.type_signature)):
        raise TypeError(
            'The supplied argument maps to TFF type {}, which is incompatible with '
            'the requested type {}.'.format(result.type_signature, type_spec))
    return result
예제 #8
0
    def test_increasing_zero_clip_sum(self):
        """Zeroing and clipping with non-integer norms that grow every round.

        The zeroing-norm process adds 0.75 to its state each round and the
        clipping-norm process adds 0.25. Five rounds over the same client data
        are checked against precomputed measurements and results.
        """

        @computations.federated_computation(_float_at_server,
                                            _float_at_clients)
        def zeroing_next_fn(state, value):
            del value
            return intrinsics.federated_map(
                computations.tf_computation(lambda x: x + 0.75, tf.float32),
                state)

        @computations.federated_computation(_float_at_server,
                                            _float_at_clients)
        def clipping_next_fn(state, value):
            del value
            return intrinsics.federated_map(
                computations.tf_computation(lambda x: x + 0.25, tf.float32),
                state)

        zeroing_norm_process = estimation_process.EstimationProcess(
            _test_init_fn, zeroing_next_fn, _test_report_fn)
        clipping_norm_process = estimation_process.EstimationProcess(
            _test_init_fn, clipping_next_fn, _test_report_fn)

        aggregation_factory = robust.zeroing_factory(
            zeroing_norm_process, _clipped_sum(clipping_norm_process))
        process = aggregation_factory.create(
            computation_types.to_type(tf.float32))

        client_data = [1.0, 2.0, 3.0]
        # One row per round: (zeroing_norm, clipping_norm, zeroed_count,
        # clipped_count, aggregated result).
        expected_per_round = [
            (1.0, 1.0, 2, 0, 1.0),
            (1.75, 1.25, 2, 0, 1.0),
            (2.5, 1.5, 1, 1, 2.5),
            (3.25, 1.75, 0, 2, 4.5),
            (4.0, 2.0, 0, 1, 5.0),
        ]

        state = process.initialize()
        for (zeroing_norm, clipping_norm, zeroed_count, clipped_count,
             expected_result) in expected_per_round:
            output = process.next(state, client_data)
            measurements = output.measurements
            self.assertAllClose(zeroing_norm, measurements['zeroing_norm'])
            self.assertAllClose(clipping_norm,
                                measurements['zeroing']['clipping_norm'])
            self.assertEqual(zeroed_count, measurements['zeroed_count'])
            self.assertEqual(clipped_count,
                             measurements['zeroing']['clipped_count'])
            self.assertAllClose(expected_result, output.result)
            state = output.state
예제 #9
0
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import errors
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.templates import client_works

# Shorthand federated types: a tensor (or sequence) of the given dtype placed
# at SERVER or CLIENTS.
SERVER_INT = computation_types.FederatedType(tf.int32, placements.SERVER)
SERVER_FLOAT = computation_types.FederatedType(tf.float32, placements.SERVER)
# Clients-placed sequences of float32 elements.
CLIENTS_FLOAT_SEQUENCE = computation_types.FederatedType(
    computation_types.SequenceType(tf.float32), placements.CLIENTS)
CLIENTS_FLOAT = computation_types.FederatedType(tf.float32, placements.CLIENTS)
CLIENTS_INT = computation_types.FederatedType(tf.int32, placements.CLIENTS)
# Clients-placed `ModelWeights` built from `ModelWeights(tf.float32, ())`;
# presumably a float32 trainable part and an empty non-trainable part. Confirm
# against the `ModelWeights` field order.
MODEL_WEIGHTS_TYPE = computation_types.at_clients(
    computation_types.to_type(model_utils.ModelWeights(tf.float32, ())))
# Local alias for brevity.
MeasuredProcessOutput = measured_process.MeasuredProcessOutput


def server_zero():
    """Returns the integer constant 0 as a federated value at SERVER."""
    zero = 0
    return intrinsics.federated_value(zero, placements.SERVER)


def client_one():
    """Returns the float constant 1.0 as a federated value at CLIENTS."""
    one = 1.0
    return intrinsics.federated_value(one, placements.CLIENTS)


def federated_add(a, b):
    """Adds two federated values.

    Maps a TF computation that returns the sum of its two arguments over the
    pair `(a, b)`.

    Args:
      a: The first federated operand.
      b: The second federated operand.

    Returns:
      The result of `intrinsics.federated_map` applied to the adder.
    """
    add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
    operands = (a, b)
    return intrinsics.federated_map(add_fn, operands)
예제 #10
0
 def test_tff_value_types_raise_on(self, value_type):
     """Checks that `create` raises `TypeError` for unsupported value types."""
     factory_under_test = _make_test_factory()
     tff_type = computation_types.to_type(value_type)
     expected_message = 'Expected `value_type` to be'
     with self.assertRaisesRegex(TypeError, expected_message):
         factory_under_test.create(tff_type)
예제 #11
0
 def test_component_tensor_dtypes_raise_on(self, value_type):
     """Checks that `create` raises `TypeError` for unsupported dtypes."""
     factory_under_test = _make_test_factory()
     tff_type = computation_types.to_type(value_type)
     expected_message = 'must all be integers or floats'
     with self.assertRaisesRegex(TypeError, expected_message):
         factory_under_test.create(tff_type)
예제 #12
0
    def test_type_properties(self, value_type, mechanism):
        """Checks the type signatures of the process built by the DDP factory.

        The expected `initialize`/`next` signatures come from composing the
        component factories directly. The chain is concat -> Hadamard rotation
        -> L2 clipping -> discretization -> differential privacy -> secure
        modular sum. These are compared against the process produced by
        `_make_test_factory(mechanism=mechanism)`. The test also asserts that the
        produced `next` contains no non-secure aggregation.

        Args:
          value_type: Something convertible to a TFF type with
            `computation_types.to_type`.
          mechanism: `'distributed_dgauss'` selects the distributed discrete
            Gaussian query; any other value selects the distributed Skellam
            query.
        """
        ddp_factory = _make_test_factory(mechanism=mechanism)
        self.assertIsInstance(ddp_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = ddp_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The state is a nested object with component factory states. Construct
        # test factories directly and compare the signatures.
        modsum_f = secure.SecureModularSumFactory(2**15, True)

        if mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=10.0, local_stddev=10.0)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(l1_norm_bound=10.0,
                                                      l2_norm_bound=10.0,
                                                      local_stddev=10.0)

        dp_f = differential_privacy.DifferentiallyPrivateFactory(
            dp_query, modsum_f)
        discrete_f = discretization.DiscretizationFactory(dp_f)
        l2clip_f = robust.clipping_factory(clipping_norm=10.0,
                                           inner_agg_factory=discrete_f)
        rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
        expected_process = concat.concat_factory(rot_f).create(value_type)

        # Check init_fn/state.
        expected_init_type = expected_process.initialize.type_signature
        expected_state_type = expected_init_type.result
        actual_init_type = process.initialize.type_signature
        self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

        # Check next_fn/measurements.
        tensor2type = type_conversions.type_from_tensors
        discrete_state = discrete_f.create(
            computation_types.to_type(tf.float32)).initialize()
        dp_query_state = dp_query.initial_global_state()
        dp_query_metrics_type = tensor2type(
            dp_query.derive_metrics(dp_query_state))
        expected_measurements_type = collections.OrderedDict(
            l2_clip=robust.NORM_TF_TYPE,
            scale_factor=tensor2type(discrete_state['scale_factor']),
            scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
            scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
            actual_num_clients=tf.int32,
            padded_dim=tf.int32,
            dp_query_metrics=dp_query_metrics_type)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=computation_types.at_server(
                    expected_measurements_type)))
        actual_next_type = process.next.type_signature
        self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except Exception as e:  # pylint: disable=broad-except
            # Catch `Exception` rather than everything, so interrupts and
            # `SystemExit` still propagate. The underlying error is included
            # so the failure is diagnosable.
            self.fail('Factory returned an AggregationProcess containing '
                      f'non-secure aggregation: {e}')
예제 #13
0
  def test_call_returns_result(self):
    """Checks how `PolymorphicComputation` dispatches and reuses functions.

    `TestFunctionFactory` names each concrete function it builds with a running
    counter ('1', '2', ...). `TestContext.invoke` echoes that name together with
    the parameter type, argument and `unpack` value. The expected strings below
    therefore show which concrete function served each call. A repeated
    argument signature (e.g. `fn(50)` after `fn(10)`) is expected to reuse the
    function created for it earlier rather than create a new one.
    """

    class TestContext(context_base.Context):
      """Context that describes each invocation as a string instead of running it."""

      def ingest(self, val, type_spec):
        # Values are passed through unchanged.
        return val

      def invoke(self, comp, arg):
        return 'name={},type={},arg={},unpack={}'.format(
            comp.name, comp.type_signature.parameter, arg, comp.unpack)

    class TestContextStack(context_stack_base.ContextStack):
      """Context stack whose current context is always one `TestContext`."""

      def __init__(self):
        super().__init__()
        self._context = TestContext()

      @property
      def current(self):
        return self._context

      def install(self, ctx):
        del ctx  # Unused
        # The requested context is ignored; the fixed context is returned.
        return self._context

    context_stack = TestContextStack()

    class TestFunction(computation_impl.ConcreteComputation):
      """Stub of type `(parameter_type -> string)` carrying `name` and `unpack`."""

      def __init__(self, name, unpack, parameter_type):
        self._name = name
        self._unpack = unpack
        type_signature = computation_types.FunctionType(parameter_type,
                                                        tf.string)
        test_proto = pb.Computation(
            type=type_serialization.serialize_type(type_signature))
        super().__init__(test_proto, context_stack, type_signature)

      @property
      def name(self):
        return self._name

      @property
      def unpack(self):
        return self._unpack

    class TestFunctionFactory(object):
      """Builds `TestFunction`s named by creation order; `unpack` is stored as `str(unpack)`."""

      def __init__(self):
        self._count = 0

      def __call__(self, parameter_type, unpack):
        self._count = self._count + 1
        return TestFunction(str(self._count), str(unpack), parameter_type)

    fn = function_utils.PolymorphicComputation(TestFunctionFactory())

    # Each new argument signature creates a new concrete function (names 1-4).
    self.assertEqual(fn(10), 'name=1,type=<int32>,arg=<10>,unpack=True')
    self.assertEqual(
        fn(20, x=True),
        'name=2,type=<int32,x=bool>,arg=<20,x=True>,unpack=True')
    # Requesting a function for an explicit `bool` type; this path reports
    # `unpack=None`.
    fn_with_bool_arg = fn.fn_for_argument_type(
        computation_types.to_type(tf.bool))
    self.assertEqual(
        fn_with_bool_arg(True), 'name=3,type=bool,arg=True,unpack=None')
    self.assertEqual(
        fn(30, x=40), 'name=4,type=<int32,x=int32>,arg=<30,x=40>,unpack=True')
    # Repeating the same signatures reuses the already-created functions, so
    # the names do not advance past 4.
    self.assertEqual(fn(50), 'name=1,type=<int32>,arg=<50>,unpack=True')
    self.assertEqual(
        fn(0, x=False),
        'name=2,type=<int32,x=bool>,arg=<0,x=False>,unpack=True')
    fn_with_bool_arg = fn.fn_for_argument_type(
        computation_types.to_type(tf.bool))
    self.assertEqual(
        fn_with_bool_arg(False), 'name=3,type=bool,arg=False,unpack=None')
    self.assertEqual(
        fn(60, x=70), 'name=4,type=<int32,x=int32>,arg=<60,x=70>,unpack=True')
예제 #14
0
      te.encoders.hadamard_quantization(8), value_spec)


def _one_over_n_encoder_fn(value_spec):
  """Returns a gather encoder built from `PlusOneOverNEncodingStage`."""
  stage = te.testing.PlusOneOverNEncodingStage()
  encoder = te.core.EncoderComposer(stage).make()
  return te.encoders.as_gather_encoder(encoder, value_spec)


def _state_update_encoder_fn(value_spec):
  """Returns a gather encoder built from `StateUpdateTensorsEncodingStage`."""
  stage = StateUpdateTensorsEncodingStage()
  encoder = te.core.EncoderComposer(stage).make()
  return te.encoders.as_gather_encoder(encoder, value_spec)


# Two-element struct type. NOTE(review): `(tf.float32, (20,))` is presumably
# read by `to_type` as a float32 tensor of shape [20]; the second element is a
# scalar float32 — confirm.
_test_struct_type = computation_types.to_type(((tf.float32, (20,)), tf.float32))


class EncodedSumFactoryComputationTest(test_case.TestCase,
                                       parameterized.TestCase):
  """Tests for `encoded.EncodedSumFactory`."""

  @parameterized.named_parameters(
      ('identity_from_encoder_fn', _identity_encoder_fn),
      ('uniform_from_encoder_fn', _uniform_encoder_fn),
      ('hadamard_from_encoder_fn', _hadamard_encoder_fn),
      ('one_over_n_from_encoder_fn', _one_over_n_encoder_fn),
      ('state_update_from_encoder_fn', _state_update_encoder_fn),
  )
  def test_type_properties(self, encoder_fn):
    """Checks that the factory built from `encoder_fn` is unweighted."""
    encoded_factory = encoded.EncodedSumFactory(encoder_fn)
    expected_base = factory.UnweightedAggregationFactory
    self.assertIsInstance(encoded_factory, expected_base)
예제 #15
0
    async def create_value(self, value, type_spec=None):
        """Creates a value in this executor.

        The following kinds of `value` are supported as the input:

        * An instance of TFF computation proto containing one of the supported
          sequence intrinsics as its sole body.

        * An instance of eager TF dataset.

        * Anything that is supported by the target executor (as a pass-through).

        * A nested structure of any of the above.

        Args:
          value: The input for which to create a value.
          type_spec: An optional TFF type (required if `value` is not an
            instance of `typed_object.TypedObject`, otherwise it can be `None`).

        Returns:
          An instance of `SequenceExecutorValue` that represents the embedded
          value.

        Raises:
          ValueError: If `value` is a computation proto naming an unrecognized
            intrinsic. Also raised if `type_spec` is a struct type and the
            flattened `value` has a different number of leaves than the
            flattened `type_spec`.
        """
        if type_spec is None:
            py_typecheck.check_type(value, typed_object.TypedObject)
            type_spec = value.type_signature
        else:
            type_spec = computation_types.to_type(type_spec)
        if isinstance(type_spec, computation_types.SequenceType):
            return SequenceExecutorValue(
                _SequenceFromPayload(value, type_spec), type_spec)
        if isinstance(value, pb.Computation):
            value_type = type_serialization.deserialize_type(value.type)
            value_type.check_equivalent_to(type_spec)
            which_computation = value.WhichOneof('computation')
            # NOTE: If not a supported type of intrinsic, we let it fall through and
            # be handled by embedding in the target executor (below).
            if which_computation == 'intrinsic':
                intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(
                    value.intrinsic.uri)
                if intrinsic_def is None:
                    raise ValueError(
                        'Encountered an unrecognized intrinsic "{}".'.format(
                            value.intrinsic.uri))
                op_type = SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
                    intrinsic_def.uri)
                if op_type is not None:
                    type_analysis.check_concrete_instance_of(
                        type_spec, intrinsic_def.type_signature)
                    op = op_type(type_spec)
                    return SequenceExecutorValue(op, type_spec)
        if isinstance(type_spec, computation_types.StructType):
            if not isinstance(value, structure.Struct):
                value = structure.from_container(value)
            elements = structure.flatten(value)
            element_types = structure.flatten(type_spec)
            # `zip` below silently truncates to the shorter input. A structure
            # mismatch would otherwise embed a partial or mis-paired value, so
            # reject it explicitly.
            if len(elements) != len(element_types):
                raise ValueError(
                    'Expected a value with {} leaf elements to match type {}, '
                    'found {}.'.format(
                        len(element_types), type_spec, len(elements)))
            flat_embedded_vals = await asyncio.gather(*[
                self.create_value(el, el_type)
                for el, el_type in zip(elements, element_types)
            ])
            embedded_struct = structure.pack_sequence_as(
                value, flat_embedded_vals)
            return await self.create_struct(embedded_struct)
        target_value = await self._target_executor.create_value(
            value, type_spec)
        return SequenceExecutorValue(target_value, type_spec)
예제 #16
0
def infer_type(arg: Any) -> Optional[computation_types.Type]:
    """Infers the TFF type of the argument (a `computation_types.Type` instance).

    Warning: This function is only partially implemented.

    The kinds of arguments that are currently correctly recognized:
    * tensors, variables, and data sets
    * things that are convertible to tensors (including `numpy` arrays, builtin
      types, as well as `list`s and `tuple`s of any of the above, etc.)
    * nested lists, `tuple`s, `namedtuple`s, anonymous `tuple`s, `dict`,
      `OrderedDict`s, `dataclasses`, `attrs` classes, and `tff.TypedObject`s

    Plain (non-ordered) `dict`s are converted with their items sorted. Python
    `int`s become `int32`, or `int64` when they fall outside the `int32` range.
    Python `float`s become `float32` and `str`s become `string`.

    Args:
      arg: The argument, the TFF type of which to infer.

    Returns:
      Either an instance of `computation_types.Type`, or `None` if the argument
      is `None`.

    Raises:
      TypeError: If `arg` is a Python `int` outside the `int64` range. Also
        raised if the fallback `tf.make_tensor_proto` conversion fails with a
        `TypeError`.
    """
    if arg is None:
        return None
    elif isinstance(arg, typed_object.TypedObject):
        return arg.type_signature
    elif tf.is_tensor(arg):
        # `tf.is_tensor` returns true for some things that are not actually single
        # `tf.Tensor`s, including `tf.SparseTensor`s and `tf.RaggedTensor`s.
        if isinstance(arg, tf.RaggedTensor):
            return computation_types.StructWithPythonType(
                (('flat_values', infer_type(arg.flat_values)),
                 ('nested_row_splits', infer_type(arg.nested_row_splits))),
                tf.RaggedTensor)
        elif isinstance(arg, tf.SparseTensor):
            return computation_types.StructWithPythonType(
                (('indices', infer_type(arg.indices)),
                 ('values', infer_type(arg.values)),
                 ('dense_shape', infer_type(arg.dense_shape))),
                tf.SparseTensor)
        else:
            return computation_types.TensorType(arg.dtype.base_dtype,
                                                arg.shape)
    elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
        element_type = computation_types.to_type(arg.element_spec)
        return computation_types.SequenceType(element_type)
    elif isinstance(arg, structure.Struct):
        # Elements with a falsy name (None or '') are kept unnamed.
        return computation_types.StructType([
            (k, infer_type(v)) if k else infer_type(v)
            for k, v in structure.iter_elements(arg)
        ])
    elif py_typecheck.is_attrs(arg):
        items = named_containers.attrs_class_to_odict(arg).items()
        return computation_types.StructWithPythonType([(k, infer_type(v))
                                                       for k, v in items],
                                                      type(arg))
    elif py_typecheck.is_dataclass(arg):
        items = named_containers.dataclass_to_odict(arg).items()
        return computation_types.StructWithPythonType([(k, infer_type(v))
                                                       for k, v in items],
                                                      type(arg))
    elif py_typecheck.is_named_tuple(arg):
        # Checked before the generic `tuple` branch below so field names are kept.
        # In Python 3.8 and later `_asdict` no longer returns an OrderedDict, but
        # a regular `dict`.
        items = collections.OrderedDict(arg._asdict())
        return computation_types.StructWithPythonType(
            [(k, infer_type(v)) for k, v in items.items()], type(arg))
    elif isinstance(arg, dict):
        # Insertion order is kept only for OrderedDict; other dicts are sorted.
        # NOTE(review): `sorted(arg.items())` raises if the keys are not mutually
        # comparable (e.g. mixed str/int keys) — confirm callers never pass such dicts.
        if isinstance(arg, collections.OrderedDict):
            items = arg.items()
        else:
            items = sorted(arg.items())
        return computation_types.StructWithPythonType([(k, infer_type(v))
                                                       for k, v in items],
                                                      type(arg))
    elif isinstance(arg, (tuple, list)):
        elements = []
        all_elements_named = True
        for element in arg:
            all_elements_named &= py_typecheck.is_name_value_pair(element)
            # NOTE(review): the whole element (including a `(name, value)` pair's
            # name) is passed to `infer_type`. Each entry is therefore an inferred
            # type, not a `(name, type)` pair. Confirm that `StructType(elements)`
            # below yields the intended named fields.
            elements.append(infer_type(element))
        # If this is a tuple of (name, value) pairs, the caller most likely intended
        # this to be a StructType, so we avoid storing the Python container.
        if elements and all_elements_named:
            return computation_types.StructType(elements)
        else:
            return computation_types.StructWithPythonType(elements, type(arg))
    elif isinstance(arg, str):
        return computation_types.TensorType(tf.string)
    elif isinstance(arg, (np.generic, np.ndarray)):
        # numpy scalars/arrays carry their own dtype and shape.
        return computation_types.TensorType(tf.dtypes.as_dtype(arg.dtype),
                                            arg.shape)
    else:
        # Exact-type checks: subclasses of bool/int/float (e.g. IntEnum) do not
        # match these branches and go to the `make_tensor_proto` fallback.
        arg_type = type(arg)
        if arg_type is bool:
            return computation_types.TensorType(tf.bool)
        elif arg_type is int:
            # Choose the integral type based on value.
            if arg > tf.int64.max or arg < tf.int64.min:
                raise TypeError(
                    'No integral type support for values outside range '
                    f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
            elif arg > tf.int32.max or arg < tf.int32.min:
                return computation_types.TensorType(tf.int64)
            else:
                return computation_types.TensorType(tf.int32)
        elif arg_type is float:
            return computation_types.TensorType(tf.float32)
        else:
            # Now fall back onto the heavier-weight processing, as all else failed.
            # Use make_tensor_proto() to make sure to handle it consistently with
            # how TensorFlow is handling values (e.g., recognizing int as int32, as
            # opposed to int64 as in NumPy).
            try:
                # TODO(b/113112885): Find something more lightweight we could use here.
                tensor_proto = tf.make_tensor_proto(arg)
                return computation_types.TensorType(
                    tf.dtypes.as_dtype(tensor_proto.dtype),
                    tf.TensorShape(tensor_proto.tensor_shape))
            except TypeError as e:
                raise TypeError('Could not infer the TFF type of {}.'.format(
                    py_typecheck.type_string(type(arg)))) from e
예제 #17
0
 def test_inner_federated_type_raises(self):
     """Checks `build_broadcast_process` rejects a struct of federated types."""
     with self.assertRaisesRegex(TypeError, 'FederatedType'):
         nested_federated_type = computation_types.to_type(
             [SERVER_FLOAT, SERVER_FLOAT])
         distributors.build_broadcast_process(nested_federated_type)
예제 #18
0
def _make_wrapper(clipping_norm: Union[float,
                                       estimation_process.EstimationProcess],
                  inner_agg_factory: factory.AggregationFactory,
                  make_clip_fn: Callable[[factory.ValueType],
                                         computation_base.Computation],
                  attribute_prefix: str) -> factory.AggregationFactory:
  """Constructs an aggregation factory that applies clip_fn before aggregation.

  The returned factory is weighted or unweighted to match `inner_agg_factory`.
  Each call to the resulting process's `next` does the following:
  * broadcasts the current norm to the clients;
  * maps the clip function over the clients' values;
  * aggregates the clipped values with the inner process;
  * sums the per-client "was clipped" outputs;
  * advances the norm process with the clients' norms.

  State and measurement names are formed by appending suffixes to
  `attribute_prefix`. For example, a prefix of `'zero'` gives the state keys
  `'zeroing_norm'`, `'inner_agg'` and `'zeroed_count_agg'`. It gives the
  measurement keys `'zeroing'`, `'zeroing_norm'` and `'zeroed_count'`.

  Args:
    clipping_norm: Either a float (for fixed norm) or an `EstimationProcess`
      (for adaptive norm) that specifies the norm over which the values should
      be clipped.
    inner_agg_factory: A factory specifying the type of aggregation to be done
      after the clip function has been applied.
    make_clip_fn: A callable that takes a value type and returns a
      tff.computation specifying the clip operation to apply before aggregation.
      The computation is mapped over `(value, clipping_norm)` at the clients.
      Its result is unpacked as `(clipped_value, global_norm, was_clipped)`.
    attribute_prefix: A str for prefixing state and measurement names.

  Returns:
    An aggregation factory that applies clip_fn before aggregation.
  """
  py_typecheck.check_type(inner_agg_factory,
                          (factory.UnweightedAggregationFactory,
                           factory.WeightedAggregationFactory))
  py_typecheck.check_type(clipping_norm,
                          (float, estimation_process.EstimationProcess))
  # A fixed float norm is wrapped in a process so both cases share one code path.
  if isinstance(clipping_norm, float):
    clipping_norm_process = _constant_process(clipping_norm)
  else:
    clipping_norm_process = clipping_norm
  _check_norm_process(clipping_norm_process, 'clipping_norm_process')

  # The aggregation factory that will be used to count the number of clipped
  # values at each iteration. For now we are just creating it here, but in
  # the future we may make this customizable to allow DP measurements.
  clipped_count_agg_factory = sum_factory.SumFactory()

  clipped_count_agg_process = clipped_count_agg_factory.create(
      computation_types.to_type(COUNT_TF_TYPE))

  # Appends a suffix to `attribute_prefix`, e.g. 'zero' + 'ing_norm'.
  prefix = lambda s: attribute_prefix + s

  def init_fn_impl(inner_agg_process):
    # Key order here defines the positional unpacking in `next_fn_impl`.
    state = collections.OrderedDict([
        (prefix('ing_norm'), clipping_norm_process.initialize()),
        ('inner_agg', inner_agg_process.initialize()),
        (prefix('ed_count_agg'), clipped_count_agg_process.initialize())
    ])
    return intrinsics.federated_zip(state)

  def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
    # Unpacked by position, matching the key order built in `init_fn_impl`.
    clipping_norm_state, agg_state, clipped_count_state = state

    clipping_norm = clipping_norm_process.report(clipping_norm_state)

    # The clip function runs at the clients, so they need the current norm.
    clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm)

    # TODO(b/163880757): Remove this when server-only metrics are supported.
    clipping_norm = intrinsics.federated_mean(clients_clipping_norm)

    clipped_value, global_norm, was_clipped = intrinsics.federated_map(
        clip_fn, (value, clients_clipping_norm))

    new_clipping_norm_state = clipping_norm_process.next(
        clipping_norm_state, global_norm)

    if weight is None:
      agg_output = inner_agg_process.next(agg_state, clipped_value)
    else:
      agg_output = inner_agg_process.next(agg_state, clipped_value, weight)

    clipped_count_output = clipped_count_agg_process.next(
        clipped_count_state, was_clipped)

    new_state = collections.OrderedDict([
        (prefix('ing_norm'), new_clipping_norm_state),
        ('inner_agg', agg_output.state),
        (prefix('ed_count_agg'), clipped_count_output.state)
    ])
    measurements = collections.OrderedDict([
        (prefix('ing'), agg_output.measurements),
        (prefix('ing_norm'), clipping_norm),
        (prefix('ed_count'), clipped_count_output.result)
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=agg_output.result,
        measurements=intrinsics.federated_zip(measurements))

  if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class WeightedRobustFactory(factory.WeightedAggregationFactory):
      """`WeightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType, weight_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type, weight_type)
        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(weight_type))
        def next_fn(state, value, weight):
          return next_fn_impl(state, value, clip_fn, inner_agg_process, weight)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return WeightedRobustFactory()
  else:

    class UnweightedRobustFactory(factory.UnweightedAggregationFactory):
      """`UnweightedAggregationFactory` factory for clipping large values."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        _check_value_type(value_type)

        inner_agg_process = inner_agg_factory.create(value_type)
        clip_fn = make_clip_fn(value_type)

        @computations.federated_computation()
        def init_fn():
          return init_fn_impl(inner_agg_process)

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          return next_fn_impl(state, value, clip_fn, inner_agg_process)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedRobustFactory()
예제 #19
0
def build_federated_evaluation(
    model_fn: Callable[[], model_lib.Model],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
    """Builds the TFF computation for federated evaluation of the given model.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation;
        returning the same pre-constructed model each call will result in an
        error.
      broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
        model weights on the server to the clients. It must support the
        signature `(input_values@SERVER -> output_values@CLIENTS)` and have
        empty state. If set to default None, the server model is broadcast to
        the clients using the default tff.federated_broadcast.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation.

    Returns:
      A federated computation (an instance of `tff.Computation`) that accepts
      model parameters and federated data, and returns the evaluation metrics
      as aggregated by `tff.learning.Model.federated_output_computation`.

    Raises:
      ValueError: If `broadcast_process` is given but is not a
        `MeasuredProcess`, or if it is stateful.
    """
    has_custom_broadcast = broadcast_process is not None
    if has_custom_broadcast and not isinstance(
            broadcast_process, measured_process.MeasuredProcess):
        raise ValueError(
            '`broadcast_process` must be a `MeasuredProcess`, got '
            f'{type(broadcast_process)}.')
    if has_custom_broadcast and optimizer_utils.is_stateful_process(
            broadcast_process):
        raise ValueError(
            'Cannot create a federated evaluation with a stateful '
            'broadcast process, must be stateless, has state: '
            f'{broadcast_process.initialize.type_signature.result!r}')

    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        throwaway_model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(
            throwaway_model)
        batch_type = computation_types.to_type(throwaway_model.input_spec)
    dataset_type = SequenceType(batch_type)

    @computations.tf_computation(model_weights_type, dataset_type)
    @tf.function
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluating `model_weights` on `dataset`."""
        with tf.init_scope():
            local_model = model_fn()
        local_weights = model_utils.ModelWeights.from_model(local_model)
        tf.nest.map_structure(lambda var, new_value: var.assign(new_value),
                              local_weights, incoming_model_weights)

        def reduce_fn(num_examples, batch):
            model_output = local_model.forward_pass(batch, training=False)
            if model_output.num_examples is None:
                # The model did not report a batch size, so derive it from the
                # leading dimension of its predictions.
                batch_size = tf.shape(model_output.predictions,
                                      out_type=tf.int64)[0]
            else:
                batch_size = tf.cast(model_output.num_examples, tf.int64)
            return num_examples + batch_size

        reduce_over_dataset = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        total_examples = reduce_over_dataset(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
        return collections.OrderedDict(
            local_outputs=local_model.report_local_outputs(),
            num_examples=total_examples)

    @computations.federated_computation(
        computation_types.at_server(model_weights_type),
        computation_types.at_clients(dataset_type))
    def server_eval(server_model_weights, federated_dataset):
        if not has_custom_broadcast:
            client_outputs = intrinsics.federated_map(client_eval, [
                intrinsics.federated_broadcast(server_model_weights),
                federated_dataset
            ])
        else:
            # TODO(b/179091838): Zip the measurements from the broadcast_process with
            # the result of `model.federated_output_computation` below to avoid
            # dropping these metrics.
            broadcast_output = broadcast_process.next(
                broadcast_process.initialize(), server_model_weights)
            client_outputs = intrinsics.federated_map(
                client_eval, (broadcast_output.result, federated_dataset))
        model_metrics = throwaway_model.federated_output_computation(
            client_outputs.local_outputs)
        example_count = intrinsics.federated_sum(client_outputs.num_examples)
        statistics = collections.OrderedDict(num_examples=example_count)
        return intrinsics.federated_zip(
            collections.OrderedDict(eval=model_metrics, stat=statistics))

    return server_eval
예제 #20
0
    def test_build_tf_computations_for_sum(self, encoder_constructor):
        """Checks that the partial gather computations' signatures line up.

        Each computation's parameter and result types must match the types
        produced or consumed by its neighbours: get_params -> encode ->
        accumulate/merge/report -> decode_after_sum, plus update_state.
        """
        value_spec = tf.TensorSpec((20, ), tf.float32)
        encoder = te.encoders.as_gather_encoder(encoder_constructor(),
                                                value_spec)

        _, state_type = encoding_utils._build_initial_state_tf_computation(
            encoder)
        value_type = computation_types.to_type(value_spec)
        nest_encoder = encoding_utils._build_tf_computations_for_gather(
            state_type, value_type, encoder)

        get_params_sig = nest_encoder.get_params_fn.type_signature
        encode_sig = nest_encoder.encode_fn.type_signature
        accumulate_sig = nest_encoder.accumulate_fn.type_signature
        merge_sig = nest_encoder.merge_fn.type_signature
        report_sig = nest_encoder.report_fn.type_signature
        decode_after_sum_sig = nest_encoder.decode_after_sum_fn.type_signature
        update_state_sig = nest_encoder.update_state_fn.type_signature

        params = get_params_sig.result
        encode_params_type = params[0]
        decode_before_sum_params_type = params[1]
        decode_after_sum_params_type = params[2]
        state_update_tensors_type = encode_sig.result[2]
        accumulator_type = nest_encoder.zero_fn.type_signature.result

        expected_and_actual = [
            # get_params_fn consumes the encoder state.
            (state_type, get_params_sig.parameter),
            # encode_fn consumes a value plus two of get_params_fn's outputs.
            (value_type, encode_sig.parameter[0]),
            (encode_params_type, encode_sig.parameter[1]),
            (decode_before_sum_params_type, encode_sig.parameter[2]),
            # The accumulator carries encode_fn's state update tensors.
            (state_update_tensors_type,
             accumulator_type.state_update_tensors),
            # accumulate/merge/report all operate on the accumulator.
            (accumulator_type, accumulate_sig.parameter[0]),
            (encode_sig.result, accumulate_sig.parameter[1]),
            (accumulator_type, accumulate_sig.result),
            (accumulator_type, merge_sig.parameter[0]),
            (accumulator_type, merge_sig.parameter[1]),
            (accumulator_type, merge_sig.result),
            (accumulator_type, report_sig.parameter),
            (accumulator_type, report_sig.result),
            # decode_after_sum_fn maps summed values back to the value type.
            (accumulator_type.values, decode_after_sum_sig.parameter[0]),
            (decode_after_sum_params_type, decode_after_sum_sig.parameter[1]),
            (value_type, decode_after_sum_sig.result),
            # update_state_fn folds the state update tensors into the state.
            (state_type, update_state_sig.parameter[0]),
            (state_update_tensors_type, update_state_sig.parameter[1]),
            (state_type, update_state_sig.result),
        ]
        for expected, actual in expected_and_actual:
            self.assertEqual(expected, actual)
예제 #21
0
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      When given, it must be equivalent to the type serialized in `comp`.
    device: An optional `tf.config.LogicalDevice`. If the computation's result
      contains a sequence, the returned callable is run under
      `tf.device('cpu')` instead of under this device.

  Returns:
    A one-argument callable if the computation has a parameter, otherwise a
    zero-argument callable; either one executes the computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or if `type_spec` is not equivalent to the type
      of `comp`.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp = _ensure_comp_runtime_compatible(comp)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_spec.is_equivalent_to(comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    unexpected_building_block = (
        building_blocks.ComputationBuildingBlock.from_proto(comp))
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        unexpected_building_block))

  if type_spec.is_function():
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  # TODO(b/156302055): Currently, TF will raise on any function returning a
  # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
  # this gating when we can.
  # Derived from `result_type` (rather than `type_spec.result`) so that the
  # non-function types handled just above do not fail with AttributeError.
  must_pin_function_to_cpu = type_analysis.contains(result_type,
                                                    lambda t: t.is_sequence())

  wrapped_fn = _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu,
                                               param_type, device)

  # One converter per flattened parameter leaf: tensors pass through as-is,
  # sequences are converted with `tf.data.experimental.to_variant`.
  param_fns = []
  if param_type is not None:
    for spec in structure.flatten(param_type):
      if spec.is_tensor():
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        param_fns.append(tf.data.experimental.to_variant)

  # One converter per flattened result leaf: tensors pass through as-is,
  # sequences are rebuilt with `tf.data.experimental.from_variant`.
  result_fns = []
  for spec in structure.flatten(result_type):
    if spec.is_tensor():
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      tf_structure = type_conversions.type_to_tf_structure(spec.element)

      # `tf_structure` is bound as a default to avoid late binding in the loop.
      def fn(x, tf_structure=tf_structure):
        return tf.data.experimental.from_variant(x, tf_structure)

      result_fns.append(fn)

  ops = wrapped_fn.graph.get_operations()

  # Resources produced by `HashTableV2` ops are fetched once here and passed to
  # `_call_embedded_tf` as `destroy_before_invocation`.
  eager_cleanup_ops = []
  destroy_before_invocation = []
  for op in ops:
    if op.type == 'HashTableV2':
      eager_cleanup_ops += op.outputs
  if eager_cleanup_ops:
    for resource in wrapped_fn.prune(feeds={}, fetches=eager_cleanup_ops)():
      destroy_before_invocation.append(resource)

  # Resources produced by `VarHandleOp` ops are fetched once here and passed to
  # `_call_embedded_tf` as `destroy_after_invocation`.
  lazy_cleanup_ops = []
  destroy_after_invocation = []
  for op in ops:
    if op.type == 'VarHandleOp':
      lazy_cleanup_ops += op.outputs
  if lazy_cleanup_ops:
    for resource in wrapped_fn.prune(feeds={}, fetches=lazy_cleanup_ops)():
      destroy_after_invocation.append(resource)

  def fn_to_return(arg,
                   param_fns=tuple(param_fns),
                   result_fns=tuple(result_fns),
                   result_type=result_type,
                   wrapped_fn=wrapped_fn,
                   destroy_before=tuple(destroy_before_invocation),
                   destroy_after=tuple(destroy_after_invocation)):
    # This double-function pattern works around python late binding, forcing the
    # variables to bind eagerly.
    return _call_embedded_tf(
        arg=arg,
        param_fns=param_fns,
        result_fns=result_fns,
        result_type=result_type,
        wrapped_fn=wrapped_fn,
        destroy_before_invocation=destroy_before,
        destroy_after_invocation=destroy_after)

  # pylint: disable=function-redefined
  if must_pin_function_to_cpu:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device('cpu'):
        return old_fn_to_return(x)
  elif device is not None:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device(device.name):
        return old_fn_to_return(x)

  # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
예제 #22
0
  async def create_value(
      self,
      value: Any,
      type_spec: Any = None) -> executor_value_base.ExecutorValue:
    """Embeds `value` (optionally typed by `type_spec`) in this executor.

    Supported kinds of `value`:

    * An instance of `intrinsic_defs.IntrinsicDef`.

    * An instance of `placements.PlacementLiteral`.

    * An instance of `computation_impl.ConcreteComputation`; its proto is
      embedded with a type reconciled against `type_spec`.

    * An instance of `pb.Computation` if of one of the following kinds:
      intrinsic, lambda, tensorflow, xla, or data.

    * A Python `list` if `type_spec` is a federated type.

      Note: The `value` must be a list even if it is of an `all_equal` type or
      if there is only a single participant associated with the given placement.

    * A Python value if `type_spec` is a non-functional, non-federated type.

    Args:
      value: The object to embed; one of the supported kinds listed above.
      type_spec: An optional type convertible to `tff.Type` via `tff.to_type`,
        describing `value`.

    Returns:
      An `executor_value_base.ExecutorValue` for the embedded value, produced
      by this executor's `FederatingStrategy`.

    Raises:
      TypeError: If the `value` and `type_spec` do not match.
      ValueError: If `value` is not a kind supported by the
        `FederatingExecutor`.
    """
    type_spec = computation_types.to_type(type_spec)

    if isinstance(value, intrinsic_defs.IntrinsicDef):
      type_analysis.check_concrete_instance_of(type_spec, value.type_signature)
      return self._strategy.ingest_value(value, type_spec)

    if isinstance(value, placements.PlacementLiteral):
      placement_type = (
          computation_types.PlacementType() if type_spec is None else type_spec)
      placement_type.check_placement()
      return self._strategy.ingest_value(value, placement_type)

    if isinstance(value, computation_impl.ConcreteComputation):
      computation_proto = computation_impl.ConcreteComputation.get_proto(value)
      reconciled_type = executor_utils.reconcile_value_with_type_spec(
          value, type_spec)
      return await self.create_value(computation_proto, reconciled_type)

    if isinstance(value, pb.Computation):
      proto_type = type_serialization.deserialize_type(value.type)
      if type_spec is None:
        type_spec = proto_type
      else:
        type_spec.check_assignable_from(proto_type)

      computation_kind = value.WhichOneof('computation')
      if computation_kind == 'intrinsic':
        intrinsic_uri = value.intrinsic.uri
        # Some intrinsics are handed straight to the strategy untouched.
        if intrinsic_uri in FederatingExecutor._FORWARDED_INTRINSICS:
          return self._strategy.ingest_value(value, type_spec)
        resolved_def = intrinsic_defs.uri_to_intrinsic_def(intrinsic_uri)
        if resolved_def is None:
          raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
              intrinsic_uri))
        return await self.create_value(resolved_def, type_spec)
      if computation_kind not in ('lambda', 'tensorflow', 'xla', 'data'):
        raise ValueError(
            'Unsupported computation building block of type "{}".'.format(
                computation_kind))
      return self._strategy.ingest_value(value, type_spec)

    if type_spec is not None and type_spec.is_federated():
      return await self._strategy.compute_federated_value(value, type_spec)

    # Everything else is an unplaced value: embed it in the unplaced executor
    # first, then let the strategy wrap the result.
    unplaced_value = await self._unplaced_executor.create_value(
        value, type_spec)
    return self._strategy.ingest_value(unplaced_value, type_spec)
예제 #23
0
def save_functional_model(functional_model: functional.FunctionalModel,
                          path: str):
    """Serializes a `FunctionalModel` as a `tf.SavedModel` to `path`.

    The saved `tf.Module` holds the following:
      * a concrete function that creates the initial weights;
      * flattened concrete functions for `forward_pass` and `predict_on_batch`,
        each traced once with `training=True` and once with `training=False`;
      * non-trainable string variables holding the serialized `tff.Type` protos
        needed to restore the un-flattened structures after loading.

    Args:
      functional_model: A `tff.learning.models.FunctionalModel`.
      path: A `str` directory path to serialize the model to.
    """
    m = tf.Module()

    def serialized_type_variable(type_proto):
        # Stores a serialized `tff.Type` proto in a non-trainable string
        # variable so it is saved alongside the concrete functions. All such
        # variables are created the same way, including the initial-weights
        # type spec, which previously omitted `trainable=False`.
        return tf.Variable(type_proto.SerializeToString(deterministic=True),
                           trainable=False)

    # Serialize the initial_weights values as a tf.function that creates a
    # structure of tensors with the initial weights. This way we can add it to
    # the tf.SavedModel and call it to create initial weights after
    # deserialization.
    create_initial_weights = lambda: functional_model.initial_weights
    with tf.Graph().as_default():
        concrete_structured_fn = tf.function(
            create_initial_weights).get_concrete_function()
    model_weights_tensor_specs = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
    m.create_initial_weights_type_spec = serialized_type_variable(
        type_serialization.serialize_type(
            computation_types.to_type(model_weights_tensor_specs)))

    def flat_initial_weights():
        return tf.nest.flatten(create_initial_weights())

    with tf.Graph().as_default():
        m.create_initial_weights = tf.function(
            flat_initial_weights).get_concrete_function()

    # Serialize forward pass concretely, once for training and once for
    # non-training.
    # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
    # the need for serializing two different function graphs.
    def make_concrete_flat_forward_pass(training: bool):
        """Create a concrete forward_pass function that has flattened output.

        Args:
          training: A boolean indicating whether this is a call in a training
            loop, or evaluation loop.

        Returns:
          A 2-tuple of concrete `tf.function` instance and a `tff.Type` protocol
          buffer message documenting the result structure returned by the
          concrete function.
        """
        # Save the un-flattened type spec for deserialization later.
        # Note: `training` is a Python boolean, which gets "curried", in a
        # sense, during function concretization. The resulting concrete
        # function only has parameters for `model_weights` and `batch_input`,
        # which are `tf.TensorSpec` structures here.
        with tf.Graph().as_default():
            concrete_structured_fn = (
                functional_model.forward_pass.get_concrete_function(
                    model_weights_tensor_specs,
                    functional_model.input_spec,
                    # Note: training does not appear in the resulting concrete
                    # function.
                    training=training))
        output_tensor_spec_structure = tf.nest.map_structure(
            tf.TensorSpec.from_tensor,
            concrete_structured_fn.structured_outputs)
        result_type_spec = type_serialization.serialize_type(
            computation_types.to_type(output_tensor_spec_structure))

        @tf.function
        def flat_forward_pass(model_weights, batch_input, training):
            return tf.nest.flatten(
                functional_model.forward_pass(model_weights, batch_input,
                                              training))

        with tf.Graph().as_default():
            flat_concrete_fn = flat_forward_pass.get_concrete_function(
                model_weights_tensor_specs,
                functional_model.input_spec,
                # Note: training does not appear in the resulting concrete
                # function.
                training=training)
        return flat_concrete_fn, result_type_spec

    fw_pass_training, fw_pass_training_type_spec = (
        make_concrete_flat_forward_pass(training=True))
    m.flat_forward_pass_training = fw_pass_training
    m.forward_pass_training_type_spec = serialized_type_variable(
        fw_pass_training_type_spec)

    fw_pass_inference, fw_pass_inference_type_spec = (
        make_concrete_flat_forward_pass(training=False))
    m.flat_forward_pass_inference = fw_pass_inference
    m.forward_pass_inference_type_spec = serialized_type_variable(
        fw_pass_inference_type_spec)

    # Serialize predict_on_batch, once for training, once for non-training.
    # NOTE(review): presumably `input_spec` is an `(x, y)` pair and
    # `predict_on_batch` only receives `x` -- confirm against FunctionalModel.
    x_type = functional_model.input_spec[0]

    # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
    # the need for serializing two different function graphs.
    def make_concrete_flat_predict_on_batch(training: bool):
        """Create a concrete predict_on_batch function that has flattened output.

        Args:
          training: A boolean indicating whether this is a call in a training
            loop, or evaluation loop.

        Returns:
          A 2-tuple of concrete `tf.function` instance and a `tff.Type` protocol
          buffer message documenting the result structure returned by the
          concrete function.
        """
        # Save the un-flattened type spec for deserialization later.
        # Note: `training` is a Python boolean, which gets "curried", in a
        # sense, during function concretization. The resulting concrete
        # function only has parameters for `model_weights` and `x`, which are
        # `tf.TensorSpec` structures here.
        concrete_structured_fn = tf.function(
            functional_model.predict_on_batch
        ).get_concrete_function(
            model_weights_tensor_specs,
            x_type,
            # Note: training does not appear in the resulting concrete function.
            training=training)
        output_tensor_spec_structure = tf.nest.map_structure(
            tf.TensorSpec.from_tensor,
            concrete_structured_fn.structured_outputs)
        result_type_spec = type_serialization.serialize_type(
            computation_types.to_type(output_tensor_spec_structure))

        @tf.function
        def flat_predict_on_batch(model_weights, x, training):
            return tf.nest.flatten(
                functional_model.predict_on_batch(model_weights, x, training))

        # `flat_predict_on_batch` is already a `tf.function`, so it can be
        # concretized directly without another `tf.function` wrapper.
        flat_concrete_fn = flat_predict_on_batch.get_concrete_function(
            model_weights_tensor_specs,
            x_type,
            # Note: training does not appear in the resulting concrete function.
            training=training)
        return flat_concrete_fn, result_type_spec

    # Both traces of each helper share one fresh graph, as before.
    with tf.Graph().as_default():
        predict_training, predict_training_type_spec = (
            make_concrete_flat_predict_on_batch(training=True))
    m.predict_on_batch_training = predict_training
    m.predict_on_batch_training_type_spec = serialized_type_variable(
        predict_training_type_spec)

    with tf.Graph().as_default():
        predict_inference, predict_inference_type_spec = (
            make_concrete_flat_predict_on_batch(training=False))
    m.predict_on_batch_inference = predict_inference
    m.predict_on_batch_inference_type_spec = serialized_type_variable(
        predict_inference_type_spec)

    # Serialize TFF values as string variables that contain the serialized
    # protos from the computation or the type.
    m.serialized_input_spec = serialized_type_variable(
        type_serialization.serialize_type(
            computation_types.to_type(functional_model.input_spec)))

    # Save everything
    _save_tensorflow_module(m, path)
예제 #24
0
def _extract_intrinsics_to_top_level_lambda(comp, uri):
    r"""Extracts intrinsics in `comp` for the given `uri`.

    This transformation creates an AST such that all the called intrinsics for
    the given `uri` in body of the `building_blocks.Block` returned by the top
    level lambda have been extracted to the top level lambda and replaced by
    selections from a reference to the constructed variable.

                         Lambda
                         |
                         Block
                        /     \
          [x=Struct, ...]       Comp
             |
             [Call,                  Call                   Call]
             /    \                 /    \                 /    \
    Intrinsic      Comp    Intrinsic      Comp    Intrinsic      Comp

    The order of the extracted called intrinsics matches the order of `uri`.

    Note: if this function is passed an AST which contains nested called
    intrinsics, it will fail, as it will mutate the subcomputation containing
    the lower-level called intrinsics on the way back up the tree.

    Args:
      comp: The `building_blocks.Lambda` to transform. The names of lambda
        parameters and block variables in `comp` must be unique.
      uri: A list of intrinsic URIs (`str`).

    Returns:
      A 2-tuple `(result, modified)`. If `comp` contains no called intrinsic
      whose URI is in `uri`, this is `(comp, False)`; otherwise `result` is the
      transformed computation and `modified` is `True`.

    Raises:
      ValueError: If all the intrinsics for the given `uri` in `comp` are not
        exclusively bound by `comp`.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for x in uri:
        py_typecheck.check_type(x, str)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    intrinsics = _get_called_intrinsics(comp, uri)
    for intrinsic in intrinsics:
        if not tree_analysis.contains_no_unbound_references(
                intrinsic, comp.parameter_name):
            raise ValueError(
                'Expected a computation which binds all the references in all the '
                'intrinsic with the uri: {}.'.format(uri))
    # Nothing to extract: previously `intrinsics[0]` raised IndexError here,
    # contrary to the documented "return the original `comp`" behavior.
    if not intrinsics:
        return comp, False
    if len(intrinsics) > 1:
        # Order the extracted intrinsics by the first position of their URI in
        # `uri`.
        order = {}
        for index, element in enumerate(uri):
            order.setdefault(element, index)
        intrinsics = sorted(intrinsics, key=lambda x: order[x.function.uri])
        extracted_comp = building_blocks.Struct(intrinsics)
    else:
        extracted_comp = intrinsics[0]
    ref_name = next(name_generator)
    ref_type = computation_types.to_type(extracted_comp.type_signature)
    ref = building_blocks.Reference(ref_name, ref_type)

    def _transform(subcomp):
        if not building_block_analysis.is_called_intrinsic(subcomp, uri):
            return subcomp, False
        if len(intrinsics) > 1:
            # `intrinsics` is in the same order as the elements of
            # `extracted_comp`, so its index is the selection index.
            index = intrinsics.index(subcomp)
            return building_blocks.Selection(ref, index=index), True
        return ref, True

    comp, _ = transformation_utils.transform_postorder(comp, _transform)
    comp = _insert_comp_in_top_level_lambda(comp,
                                            name=ref.name,
                                            comp_to_insert=extracted_comp)
    return comp, True
예제 #25
0
def create_whimsy_computation_tensorflow_identity(arg_type=tf.float32):
    """Builds a TensorFlow identity computation over `arg_type`.

    Args:
      arg_type: Anything accepted by `computation_types.to_type`. The default,
        `tf.float32`, yields a computation of type `(float32 -> float32)`.

    Returns:
      A `(value, type_signature)` pair as produced by
      `tensorflow_computation_factory.create_identity`.
    """
    identity_arg_type = computation_types.to_type(arg_type)
    proto, signature = tensorflow_computation_factory.create_identity(
        identity_arg_type)
    return proto, signature
예제 #26
0
def _group_by_intrinsics_in_top_level_lambda(comp):
    """Groups the intrinsics in the first block local in the result of `comp`.

    This transformation creates an AST by replacing the tuple of called
    intrinsics found as the first local in the `building_blocks.Block` returned
    by the top level lambda with two new computations. The first computation is
    a tuple of tuples of called intrinsics, representing the original tuple of
    called intrinsics grouped by URI. The second computation is a tuple of
    selections from the first computation, representing the original tuple of
    called intrinsics.

    It is necessary to group intrinsics before it is possible to merge them.

    Args:
      comp: The `building_blocks.Lambda` to transform.

    Returns:
      A 2-tuple `(result, modified)`. If every called intrinsic in the first
      local has a distinct URI, this is `(comp, False)`. Otherwise `result` is a
      new `building_blocks.Lambda` returning a `building_blocks.Block` whose
      first local is a tuple of (tuples of) called intrinsics grouped by URI,
      and `modified` is `True`.

    Raises:
      ValueError: If the first local in the `building_blocks.Block` referenced
        by the top level lambda is not a `building_blocks.Struct` of called
        intrinsics.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.result, building_blocks.Block)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    name, first_local = comp.result.locals[0]
    py_typecheck.check_type(first_local, building_blocks.Struct)
    for element in first_local:
        if not building_block_analysis.is_called_intrinsic(element):
            raise ValueError(
                'Expected all the elements of the `building_blocks.Struct` to be '
                'called intrinsics, but found: \n{}'.format(element))

    # Create collections of data describing how to pack and unpack the
    # intrinsics into groups by their URI.
    #
    # packed_keys is a list of unique URIs ordered by first occurrence in the
    #   original tuple of called intrinsics; group_index_by_uri maps each URI to
    #   its position in packed_keys (avoids repeated linear `list.index`).
    # packed_groups is a `collections.OrderedDict` where each key is a URI to
    #   group by and each value is a list of intrinsics with that URI.
    # packed_indexes is a list of (group index, index within group) pairs; the
    #   position of a pair in packed_indexes is the position of the intrinsic in
    #   the original tuple, so it maps unpacked index -> packed indexes.
    packed_keys = []
    group_index_by_uri = {}
    for called_intrinsic in first_local:
        uri = called_intrinsic.function.uri
        if uri not in group_index_by_uri:
            group_index_by_uri[uri] = len(packed_keys)
            packed_keys.append(uri)
    # If there are no duplicates, return early.
    if len(packed_keys) == len(first_local):
        return comp, False
    packed_groups = collections.OrderedDict((x, []) for x in packed_keys)
    packed_indexes = []
    for called_intrinsic in first_local:
        uri = called_intrinsic.function.uri
        packed_group = packed_groups[uri]
        packed_group.append(called_intrinsic)
        packed_indexes.append((group_index_by_uri[uri], len(packed_group) - 1))

    # Groups of size one are stored unwrapped rather than as 1-element structs.
    packed_elements = []
    for called_intrinsics in packed_groups.values():
        if len(called_intrinsics) > 1:
            element = building_blocks.Struct(called_intrinsics)
        else:
            element = called_intrinsics[0]
        packed_elements.append(element)
    packed_comp = building_blocks.Struct(packed_elements)

    packed_ref_name = next(name_generator)
    packed_ref_type = computation_types.to_type(packed_comp.type_signature)
    packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

    unpacked_elements = []
    for group_index, intrinsic_index in packed_indexes:
        sel = building_blocks.Selection(packed_ref, index=group_index)
        if len(packed_groups[packed_keys[group_index]]) > 1:
            sel = building_blocks.Selection(sel, index=intrinsic_index)
        unpacked_elements.append(sel)
    unpacked_comp = building_blocks.Struct(unpacked_elements)

    # Copy the locals so that the list obtained from the input block is never
    # mutated in place.
    variables = list(comp.result.locals)
    variables[0] = (name, unpacked_comp)
    variables.insert(0, (packed_ref_name, packed_comp))
    block = building_blocks.Block(variables, comp.result.result)
    fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                block)
    return fn, True
예제 #27
0
    def test_init_does_not_raise_type_error_with_unknown_dimensions(self):
        # The declared types use an unknown leading dimension (`[None]`), while
        # several computations below return tensors of a concrete length.
        # Building the `MapReduceForm` must accept these assignable-but-unequal
        # types without raising `TypeError`.
        server_state_type = computation_types.TensorType(
            shape=[None], dtype=tf.int32)
        string_vector_type = computation_types.TensorType(
            shape=[None], dtype=tf.string)

        @tensorflow_computation.tf_computation
        def initialize():
            # Shape [3] int32: assignable to, but not equal to,
            # `server_state_type`.
            return tf.constant([1, 2, 3])

        @tensorflow_computation.tf_computation(server_state_type)
        def prepare(server_state):
            del server_state  # Unused
            return tf.constant(1.0)

        @tensorflow_computation.tf_computation(
            computation_types.SequenceType(tf.float32), tf.float32)
        def work(client_data, client_input):
            del client_data, client_input  # Unused
            return True, [], [], []

        @tensorflow_computation.tf_computation
        def zero():
            return tf.constant([], dtype=tf.string)

        @tensorflow_computation.tf_computation(string_vector_type, tf.bool)
        def accumulate(accumulator, client_update):
            del accumulator, client_update  # Unused
            return tf.constant(['abc'])

        @tensorflow_computation.tf_computation(string_vector_type,
                                               string_vector_type)
        def merge(accumulator1, accumulator2):
            del accumulator1, accumulator2  # Unused
            return tf.constant(['abc'])

        @tensorflow_computation.tf_computation(string_vector_type)
        def report(accumulator):
            del accumulator  # Unused
            return tf.constant(1.0)

        # One computation returning an empty structure serves as `bitwidth`,
        # `max_input` and `modulus`.
        bitwidth = max_input = modulus = tensorflow_computation.tf_computation(
            lambda: [])
        unit_type = computation_types.to_type([])

        @tensorflow_computation.tf_computation(
            server_state_type, (tf.float32, unit_type, unit_type, unit_type))
        def update(server_state, global_update):
            del server_state, global_update  # Unused
            # Shape [1] int32: assignable to `server_state_type`, and different
            # from the [3]-shaped value returned by `initialize`.
            return tf.constant([1]), []

        try:
            forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                                merge, report, bitwidth, max_input, modulus,
                                update)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')
예제 #28
0
def pack_args_into_struct(
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        type_spec=None,
        context: Optional[context_base.Context] = None) -> structure.Struct:
    """Packs positional and keyword arguments into a `Struct`.

    If `type_spec` is not None, it must be a `StructType` or something that's
    convertible to it by `computation_types.to_type()`. The assignment of
    arguments to fields of the struct follows the same rule as during function
    calls. If `type_spec` is None, the positional arguments precede any of the
    keyword arguments, and the ordering of the keyword arguments matches the
    ordering in which they appear in `kwargs`.

    Args:
      args: Positional arguments.
      kwargs: Keyword arguments.
      type_spec: The optional type specification (either an instance of
        `computation_types.StructType` or something convertible to it), or None
        if there's no type. Used to drive the arrangement of args into fields of
        the constructed struct, as noted in the description. An empty
        `StructType` is treated as a real type, not as "no type".
      context: The optional context (an instance of `context_base.Context`) in
        which the arguments are being packed. Required if and only if the
        `type_spec` is not `None`.

    Returns:
      A struct containing all the arguments.

    Raises:
      TypeError: if the arguments are of the wrong types, an argument is given
        twice or is missing, or some positional/keyword arguments are unused.
    """
    type_spec = computation_types.to_type(type_spec)
    # `is None` rather than truthiness: `StructType` has a length, so an empty
    # struct type is falsy and would otherwise silently bypass the typed path
    # (no context ingestion, no unused-argument checks).
    if type_spec is None:
        return structure.Struct([(None, arg) for arg in args] +
                                list(kwargs.items()))

    py_typecheck.check_type(type_spec, computation_types.StructType)
    py_typecheck.check_type(context, context_base.Context)
    context = typing.cast(context_base.Context, context)
    if not is_argument_struct(type_spec):  # pylint: disable=attribute-error
        raise TypeError(
            'Parameter type {} does not have a structure of an argument struct, '
            'and cannot be populated from multiple positional and keyword '
            'arguments'.format(type_spec))

    result_elements = []
    positions_used = set()
    keywords_used = set()
    for index, (name, elem_type) in enumerate(
            structure.to_elements(type_spec)):
        if index < len(args):
            # This argument is present in `args`.
            if name is not None and name in kwargs:
                raise TypeError('Argument `{}` specified twice.'.format(name))
            result_elements.append(
                (name, context.ingest(args[index], elem_type)))
            positions_used.add(index)
        elif name is not None and name in kwargs:
            # This argument is present in `kwargs`.
            result_elements.append(
                (name, context.ingest(kwargs[name], elem_type)))
            keywords_used.add(name)
        elif name:
            raise TypeError(f'Missing argument `{name}` of type {elem_type}.')
        else:
            raise TypeError(
                f'Missing argument of type {elem_type} at position {index}.')

    # Positional arguments beyond the struct's fields are never consumed.
    positions_missing = set(range(len(args))).difference(positions_used)
    if positions_missing:
        raise TypeError(
            f'Positional arguments at {positions_missing} not used.')
    keywords_missing = set(kwargs.keys()).difference(keywords_used)
    if keywords_missing:
        raise TypeError(f'Keyword arguments at {keywords_missing} not used.')
    return structure.Struct(result_elements)
예제 #29
0
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.optimizers import sgdm
from tensorflow_federated.python.learning.templates import client_works
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
from tensorflow_federated.python.learning.templates import finalizers
from tensorflow_federated.python.learning.templates import learning_process

# Scalar `tf.float32` tensor type reused by the fixtures below.
FLOAT_TYPE = computation_types.TensorType(tf.float32)

# TFF type of a `ModelWeights` built from `FLOAT_TYPE` and an empty structure
# (presumably the trainable / non-trainable slots -- confirm against
# `model_utils.ModelWeights`).
MODEL_WEIGHTS_TYPE = computation_types.to_type(
    model_utils.ModelWeights(FLOAT_TYPE, ())
)

# Sequence of float32 scalars, placed at CLIENTS.
CLIENTS_SEQUENCE_FLOAT_TYPE = computation_types.at_clients(
    computation_types.SequenceType(FLOAT_TYPE)
)


def empty_at_server():
    """Returns the empty structure `()` as a federated value at SERVER."""
    empty_structure = ()
    return intrinsics.federated_value(empty_structure, placements.SERVER)


@federated_computation.federated_computation()
def empty_init_fn():
    """No-argument federated computation yielding an empty SERVER state."""
    server_state = empty_at_server()
    return server_state


@tensorflow_computation.tf_computation()
def test_init_model_weights_fn():
예제 #30
0
class SecureModularSumFactoryComputationTest(tf.test.TestCase,
                                             parameterized.TestCase):
    @parameterized.named_parameters(
        ('scalar_non_symmetric_int32', 8, tf.int32, False),
        ('scalar_non_symmetric_int64', 8, tf.int64, False),
        ('struct_non_symmetric', 8, _test_struct_type(tf.int32), False),
        ('scalar_symmetric_int32', 8, tf.int32, True),
        ('scalar_symmetric_int64', 8, tf.int64, True),
        ('struct_symmetric', 8, _test_struct_type(tf.int32), True),
        ('numpy_modulus_non_symmetric', np.int32(8), tf.int32, False),
        ('numpy_modulus_symmetric', np.int32(8), tf.int32, True),
    )
    def test_type_properties(self, modulus, value_type, symmetric_range):
        """Checks the created process's type signatures and secure-only use.

        Args:
          modulus: Modulus forwarded to `secure.SecureModularSumFactory`.
          value_type: Anything `computation_types.to_type` accepts; it is the
            type of the client values the process aggregates.
          symmetric_range: Forwarded as `symmetric_range` to the factory.
        """
        factory_ = secure.SecureModularSumFactory(
            modulus=modulus, symmetric_range=symmetric_range)
        self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = factory_.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Both state and measurements are expected to be `()` at SERVER.
        expected_state_type = computation_types.at_server(
            computation_types.to_type(()))
        expected_measurements_type = expected_state_type

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))

        # Catch `Exception` instead of using a bare `except:`. A bare except
        # would also swallow KeyboardInterrupt/SystemExit. Include the
        # underlying error so the failure explains which aggregation was found.
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except Exception as e:  # pylint: disable=broad-except
            self.fail('Factory returned an AggregationProcess containing '
                      f'non-secure aggregation: {e}')

    def test_float_modulus_raises(self):
        """Python and NumPy floating-point moduli are rejected with TypeError."""
        for float_modulus in (8.0, np.float32(8.0)):
            with self.assertRaises(TypeError):
                secure.SecureModularSumFactory(modulus=float_modulus)

    def test_modulus_not_positive_raises(self):
        """A zero or negative modulus is rejected with ValueError."""
        for non_positive_modulus in (0, -1):
            with self.assertRaises(ValueError):
                secure.SecureModularSumFactory(modulus=non_positive_modulus)

    def test_symmetric_range_not_bool_raises(self):
        """A non-bool `symmetric_range` (the string 'True') raises TypeError."""
        self.assertRaises(
            TypeError,
            secure.SecureModularSumFactory,
            modulus=8,
            symmetric_range='True')

    @parameterized.named_parameters(
        ('float_type', computation_types.TensorType(tf.float32)),
        ('mixed_type', computation_types.to_type([tf.float32, tf.int32])),
        ('federated_type',
         computation_types.FederatedType(tf.int32, placements.SERVER)),
        ('function_type', computation_types.FunctionType(None, ())),
        ('sequence_type', computation_types.SequenceType(tf.float32)),
    )
    def test_incorrect_value_type_raises(self, bad_value_type):
        """`create` raises TypeError for each unsupported value type above."""
        with self.assertRaises(TypeError):
            modular_sum_factory = secure.SecureModularSumFactory(8)
            modular_sum_factory.create(bad_value_type)