예제 #1
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Creates an `AggregationProcess` applying randomized Hadamard rotations.

    Client side: each tensor is preprocessed by `_flatten_and_pad_zeros_pow2`.
    Then, `self._num_repeats` times, the tensor is multiplied elementwise by a
    seeded `sample_rademacher` draw and passed through
    `hadamard.fast_walsh_hadamard_transform`.

    Aggregation: the rotated structure is aggregated by the process that
    `self._inner_agg_factory` creates for it.

    Server side: the same two steps are applied in the opposite order
    (Hadamard first, then the Rademacher draw), with seeds counting down.
    Each tensor is then mapped back onto its spec from `value_type` via
    `_slice_and_reshape_to_template_spec`.

    Args:
      value_type: The unplaced type of the values to aggregate, validated by
        `_check_value_type`.

    Returns:
      A `tff.templates.AggregationProcess` whose `next` is assembled by
      `_build_next_fn` (tagged `'hd'`).
    """
    _check_value_type(value_type)
    template_specs = type_conversions.type_to_tf_tensor_specs(value_type)
    num_repeats = self._num_repeats
    num_tensors = len(structure.flatten(value_type))
    # Every tensor consumes `num_repeats` consecutive seeds per round.
    next_global_seed_fn = _build_next_global_seed_fn(
        stride=num_repeats * num_tensors)

    def hadamard_with_unit_batch(tensor):
      # Give the tensor an extra leading axis of size 1 for the transform,
      # then drop that axis again.
      batched = tf.expand_dims(tensor, axis=0)
      transformed = hadamard.fast_walsh_hadamard_transform(batched)
      return tf.squeeze(transformed, axis=0)

    @tensorflow_computation.tf_computation(value_type, SEED_TFF_TYPE)
    def client_transform(value, global_seed):

      @tf.function
      def rotate(tensor, seed):
        for _ in range(num_repeats):
          signs = sample_rademacher(tf.shape(tensor), tensor.dtype, seed)
          tensor = hadamard_with_unit_batch(tensor * signs)
          seed += 1
        return tensor

      padded = _flatten_and_pad_zeros_pow2(value)
      seeds = _unique_seeds_for_struct(padded, global_seed, stride=num_repeats)
      return tf.nest.map_structure(rotate, padded, seeds)

    rotated_type = client_transform.type_signature.result
    inner_agg_process = self._inner_agg_factory.create(rotated_type)

    @tensorflow_computation.tf_computation(rotated_type, SEED_TFF_TYPE)
    def server_transform(value, global_seed):

      @tf.function
      def unrotate(tensor, seed):
        # Consume the seeds in reverse order relative to `client_transform`.
        seed += num_repeats - 1
        for _ in range(num_repeats):
          tensor = hadamard_with_unit_batch(tensor)
          tensor *= sample_rademacher(tf.shape(tensor), tensor.dtype, seed)
          seed -= 1
        return tensor

      seeds = _unique_seeds_for_struct(value, global_seed, stride=num_repeats)
      unrotated = tf.nest.map_structure(unrotate, value, seeds)
      return tf.nest.map_structure(_slice_and_reshape_to_template_spec,
                                   unrotated, template_specs)

    @federated_computation.federated_computation()
    def init_fn():
      return intrinsics.federated_zip((
          inner_agg_process.initialize(),
          intrinsics.federated_eval(
              tensorflow_computation.tf_computation(_init_global_seed),
              placements.SERVER),
      ))

    @federated_computation.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type))
    def next_fn(state, value):
      round_fn = _build_next_fn(client_transform, inner_agg_process,
                                server_transform, next_global_seed_fn, 'hd')
      return round_fn(state, value)

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #2
0
    def __init__(self, initialize_fn, next_fn):
        """Validates the type signatures of a finalizer's computations.

        In addition to the checks done by the base-class constructor (called
        with `next_is_multi_arg=True`), this enforces the requirements below.

        Args:
          initialize_fn: A federated computation whose result (the process
            state) must be a federated type placed at SERVER.
          next_fn: A federated computation taking exactly three arguments.
            Every leaf of its parameter and result types must be federated.
            The second argument (`model_weights_param` below) and the third
            (`update_from_clients_param` below) must be placed at SERVER.
            Its result must have a federated `result` attribute placed at
            SERVER, whose member is assignable to the second argument's
            member. It must also have a `measurements` attribute placed at
            SERVER.

        Raises:
          errors.TemplateNotFederatedError: If `initialize_fn` does not return
            a federated type, or a leaf of `next_fn`'s types is not federated.
          errors.TemplatePlacementError: If a placement requirement above is
            violated.
          errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly
            three arguments.
          FinalizerResultTypeError: If the `result` member is not assignable to
            the member of the second argument.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        if not initialize_fn.type_signature.result.is_federated():
            raise errors.TemplateNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')
        next_types = (structure.flatten(next_fn.type_signature.parameter) +
                      structure.flatten(next_fn.type_signature.result))
        non_federated_types = [t for t in next_types if not t.is_federated()]
        if non_federated_types:
            # `str.join` only accepts strings; joining the `tff.Type` objects
            # themselves would raise `TypeError` instead of the intended error.
            offending_types = '\n- '.join(str(t) for t in non_federated_types)
            raise errors.TemplateNotFederatedError(
                f'Provided `next_fn` must be a *federated* computation, that is, '
                f'operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_fn.type_signature}\n'
                f'The non-federated types are:\n {offending_types}.')

        if initialize_fn.type_signature.result.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The state controlled by an `FinalizerProcess` must be placed at '
                f'the SERVER, but found type: {initialize_fn.type_signature.result}.'
            )
        # Note that state of next_fn being placed at SERVER is now ensured by the
        # assertions in base class which would otherwise raise
        # TemplateStateNotAssignableError.

        next_fn_param = next_fn.type_signature.parameter
        if not next_fn_param.is_struct():
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'the following input type which is not a Struct: {next_fn_param}.'
            )
        if len(next_fn_param) != 3:
            next_param_str = '\n- '.join([str(t) for t in next_fn_param])
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'{len(next_fn_param)} input arguments:\n{next_param_str}')
        model_weights_param = next_fn_param[1]
        update_from_clients_param = next_fn_param[2]
        if model_weights_param.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The second input argument of `next_fn` must be placed at SERVER '
                f'but found {model_weights_param}.')
        if update_from_clients_param.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The third input argument of `next_fn` must be placed at SERVER '
                f'but found {update_from_clients_param}.')

        next_fn_result = next_fn.type_signature.result
        # Guard with `is_federated()` first: a non-federated `result` (e.g. a
        # struct of federated values) has no `.placement` attribute.
        if (not next_fn_result.result.is_federated()
                or next_fn_result.result.placement != placements.SERVER):
            raise errors.TemplatePlacementError(
                f'The "result" attribute of the return type of `next_fn` must be '
                f'placed at SERVER, but found {next_fn_result.result}.')
        if not model_weights_param.member.is_assignable_from(
                next_fn_result.result.member):
            raise FinalizerResultTypeError(
                f'The second input argument of `next_fn` must match the "result" '
                f'attribute of the return type of `next_fn`. Found:\n'
                f'Second input argument: {next_fn_param[1].member}\n'
                f'Result attribute: {next_fn_result.result.member}.')
        if next_fn_result.measurements.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The "measurements" attribute of return type of `next_fn` must be '
                f'placed at SERVER, but found {next_fn_result.measurements}.')
예제 #3
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Creates an `AggregationProcess` applying randomized DFT rotations.

    Client side: each tensor is preprocessed by `_flatten_and_pad_zeros_even`.
    Then, `self._num_repeats` times:
      1. The tensor is viewed as complex numbers (first half real parts,
         second half imaginary parts).
      2. It is multiplied by a seeded `sample_cis` draw.
      3. `tf.signal.fft` is applied.
      4. The result is laid out again as real parts followed by imaginary
         parts.
      5. It is divided by `sqrt(size / 2)`.

    Aggregation: the rotated structure is aggregated by the process that
    `self._inner_agg_factory` creates for it.

    Server side: the steps are undone in reverse, with seeds counting down.
    The tensor is multiplied by `sqrt(size / 2)`, passed through
    `tf.signal.ifft`, and multiplied by `sample_cis(..., inverse=True)`.
    Each tensor is finally mapped back onto its spec from `value_type` via
    `_slice_and_reshape_to_template_spec`.

    Args:
      value_type: The unplaced type of the values to aggregate, validated by
        `_check_value_type`.

    Returns:
      A `tff.templates.AggregationProcess` whose `next` is assembled by
      `_build_next_fn` (tagged `'dft'`).
    """
    _check_value_type(value_type)
    template_specs = type_conversions.structure_from_tensor_type_tree(
        lambda tensor_type: tf.TensorSpec(tensor_type.shape, tensor_type.dtype),
        value_type)
    num_repeats = self._num_repeats
    num_tensors = len(structure.flatten(value_type))
    # Every tensor consumes `num_repeats` consecutive seeds per round.
    next_global_seed_fn = _build_next_global_seed_fn(
        stride=num_repeats * num_tensors)

    def as_complex(tensor):
      # Split into two rows: row 0 supplies the real parts, row 1 the
      # imaginary parts (assumes a rank-1 tensor of even length - presumably
      # guaranteed by `_flatten_and_pad_zeros_even`; TODO confirm).
      halves = tf.reshape(tensor, [2, -1])
      return tf.complex(real=halves[0], imag=halves[1])

    def as_real(tensor):
      # Inverse layout of `as_complex`: real parts followed by imaginary parts.
      return tf.concat([tf.math.real(tensor), tf.math.imag(tensor)], axis=0)

    @tensorflow_computation.tf_computation(value_type, SEED_TFF_TYPE)
    def client_transform(value, global_seed):

      @tf.function
      def rotate(tensor, seed):
        for _ in range(num_repeats):
          tensor = as_complex(tensor)
          tensor *= sample_cis(tf.shape(tensor), seed, inverse=False)
          tensor = as_real(tf.signal.fft(tensor))
          tensor /= tf.cast(tf.sqrt(tf.size(tensor) / 2), OUTPUT_TF_DTYPE)
          seed += 1
        return tensor

      padded = _flatten_and_pad_zeros_even(value)
      seeds = _unique_seeds_for_struct(padded, global_seed, stride=num_repeats)
      return tf.nest.map_structure(rotate, padded, seeds)

    rotated_type = client_transform.type_signature.result
    inner_agg_process = self._inner_agg_factory.create(rotated_type)

    @tensorflow_computation.tf_computation(rotated_type, SEED_TFF_TYPE)
    def server_transform(value, global_seed):

      @tf.function
      def unrotate(tensor, seed):
        # Consume the seeds in reverse order relative to `client_transform`.
        seed += num_repeats - 1
        for _ in range(num_repeats):
          tensor *= tf.sqrt(tf.size(tensor, out_type=tensor.dtype) / 2.0)
          tensor = tf.signal.ifft(as_complex(tensor))
          tensor *= sample_cis(tf.shape(tensor), seed, inverse=True)
          tensor = as_real(tensor)
          seed -= 1
        return tensor

      seeds = _unique_seeds_for_struct(value, global_seed, stride=num_repeats)
      unrotated = tf.nest.map_structure(unrotate, value, seeds)
      return tf.nest.map_structure(_slice_and_reshape_to_template_spec,
                                   unrotated, template_specs)

    @federated_computation.federated_computation()
    def init_fn():
      return intrinsics.federated_zip((
          inner_agg_process.initialize(),
          intrinsics.federated_eval(
              tensorflow_computation.tf_computation(_init_global_seed),
              placements.SERVER),
      ))

    @federated_computation.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type))
    def next_fn(state, value):
      round_fn = _build_next_fn(client_transform, inner_agg_process,
                                server_transform, next_global_seed_fn, 'dft')
      return round_fn(state, value)

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #4
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Builds a weighted-mean aggregation process for `value_type`.

    `value_type` must be an unplaced (`value_type.is_federated()` is `False`)
    `tff.TensorType` or `tff.StructType`, and every tensor in it must have a
    floating dtype.

    The returned process's `next` expects
    `<S@SERVER, {value_type}@CLIENTS, {float32}@CLIENTS>`, where `S` is the
    unplaced result type of its `initialize`. Each round proceeds as follows:
      - Clients multiply their value by their weight (`_mul`).
      - The weighted values go to the process built by
        `self._value_sum_factory`.
      - The weights go to the process built by `self._weight_sum_factory`.
      - The server divides the first aggregate by the second, using
        `_div_no_nan` when `self._no_nan_division` is truthy and `_div`
        otherwise.

    Args:
      value_type: A `tff.Type` without placement.

    Returns:
      A `tff.templates.AggregationProcess`.

    Raises:
      TypeError: If some tensor in `value_type` is not of floating dtype
        (`py_typecheck.check_type` also validates the Python type of
        `value_type` first).
    """
    py_typecheck.check_type(value_type, factory.ValueType.__args__)

    floating_flags = [t.dtype.is_floating for t in structure.flatten(value_type)]
    if not all(floating_flags):
      raise TypeError(f'All values in provided value_type must be of floating '
                      f'dtype. Provided value_type: {value_type}')

    weight_type = computation_types.to_type(tf.float32)
    value_sum_process = self._value_sum_factory.create(value_type)
    weight_sum_process = self._weight_sum_factory.create(weight_type)
    divide_fn = _div_no_nan if self._no_nan_division else _div

    @computations.federated_computation()
    def init_fn():
      initial_state = collections.OrderedDict(
          value_sum_process=value_sum_process.initialize(),
          weight_sum_process=weight_sum_process.initialize())
      return intrinsics.federated_zip(initial_state)

    @computations.federated_computation(
        init_fn.type_signature.result,
        computation_types.FederatedType(value_type, placements.CLIENTS),
        computation_types.FederatedType(weight_type, placements.CLIENTS))
    def next_fn(state, value, weight):
      # Clients scale their value by their weight.
      weighted_value = intrinsics.federated_map(_mul, (value, weight))

      # Aggregate numerator and denominator with the inner processes.
      value_output = value_sum_process.next(
          state['value_sum_process'], weighted_value)
      weight_output = weight_sum_process.next(
          state['weight_sum_process'], weight)

      # The server divides the aggregated values by the aggregated weights.
      weighted_mean_value = intrinsics.federated_map(
          divide_fn, (value_output.result, weight_output.result))

      new_state = intrinsics.federated_zip(
          collections.OrderedDict(
              value_sum_process=value_output.state,
              weight_sum_process=weight_output.state))
      measurements = intrinsics.federated_zip(
          collections.OrderedDict(
              value_sum_process=value_output.measurements,
              weight_sum_process=weight_output.measurements))
      return measured_process.MeasuredProcessOutput(
          new_state, weighted_mean_value, measurements)

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #5
0
    def __init__(self, initialize_fn, next_fn):
        """Validates the type signatures of a client-work process.

        In addition to the checks done by the base-class constructor (called
        with `next_is_multi_arg=True`), this enforces the requirements below.

        Args:
          initialize_fn: A federated computation whose result (the process
            state) must be a federated type placed at SERVER.
          next_fn: A federated computation taking exactly three arguments.
            Every leaf of its parameter and result types must be federated.
            The second and third arguments must be placed at CLIENTS. The
            third argument's member must be a sequence of a
            TensorFlow-compatible element type, or a (nested) structure of
            such sequences. Its result must have a federated `result`
            attribute placed at CLIENTS, whose member is a struct with the
            `ClientResult` Python container. It must also have a
            `measurements` attribute placed at SERVER.

        Raises:
          errors.TemplateNotFederatedError: If `initialize_fn` does not return
            a federated type, or a leaf of `next_fn`'s types is not federated.
          errors.TemplatePlacementError: If a placement requirement above is
            violated.
          errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly
            three arguments.
          ClientDataTypeError: If the third argument is not sequence-typed as
            described above.
          ClientResultTypeError: If the `result` member does not use the
            `ClientResult` container.
        """
        super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

        if not initialize_fn.type_signature.result.is_federated():
            raise errors.TemplateNotFederatedError(
                f'Provided `initialize_fn` must return a federated type, but found '
                f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
                f'see a collection of federated types, try wrapping the returned '
                f'value in `tff.federated_zip` before returning.')
        next_types = (structure.flatten(next_fn.type_signature.parameter) +
                      structure.flatten(next_fn.type_signature.result))
        non_federated_types = [t for t in next_types if not t.is_federated()]
        if non_federated_types:
            # `str.join` only accepts strings; joining the `tff.Type` objects
            # themselves would raise `TypeError` instead of the intended error.
            offending_types = '\n- '.join(str(t) for t in non_federated_types)
            raise errors.TemplateNotFederatedError(
                f'Provided `next_fn` must be a *federated* computation, that is, '
                f'operate on `tff.FederatedType`s, but found\n'
                f'next_fn with type signature:\n{next_fn.type_signature}\n'
                f'The non-federated types are:\n {offending_types}.')

        if initialize_fn.type_signature.result.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The state controlled by a `ClientWorkProcess` must be placed at '
                f'the SERVER, but found type: {initialize_fn.type_signature.result}.'
            )
        # Note that state of next_fn being placed at SERVER is now ensured by the
        # assertions in base class which would otherwise raise
        # TemplateStateNotAssignableError.

        next_fn_param = next_fn.type_signature.parameter
        if not next_fn_param.is_struct():
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'the following input type which is not a Struct: {next_fn_param}.'
            )
        if len(next_fn_param) != 3:
            next_param_str = '\n- '.join([str(t) for t in next_fn_param])
            raise errors.TemplateNextFnNumArgsError(
                f'The `next_fn` must have exactly three input arguments, but found '
                f'{len(next_fn_param)} input arguments:\n{next_param_str}')
        second_next_param = next_fn_param[1]
        client_data_param = next_fn_param[2]
        if second_next_param.placement != placements.CLIENTS:
            raise errors.TemplatePlacementError(
                f'The second input argument of `next_fn` must be placed at CLIENTS '
                f'but found {second_next_param}.')
        if client_data_param.placement != placements.CLIENTS:
            raise errors.TemplatePlacementError(
                f'The third input argument of `next_fn` must be placed at CLIENTS '
                f'but found {client_data_param}.')

        def is_allowed_client_data_type(
                type_spec: computation_types.Type) -> bool:
            """Returns True for TF-compatible sequences or structs of them."""
            if type_spec.is_sequence():
                return type_analysis.is_tensorflow_compatible_type(
                    type_spec.element)
            elif type_spec.is_struct():
                return all(
                    is_allowed_client_data_type(element_type)
                    for element_type in type_spec.children())
            else:
                return False

        if not is_allowed_client_data_type(client_data_param.member):
            raise ClientDataTypeError(
                f'The third input argument of `next_fn` must be a sequence or '
                f'a structure of sequences, but found {client_data_param}.')

        next_fn_result = next_fn.type_signature.result
        if (not next_fn_result.result.is_federated()
                or next_fn_result.result.placement != placements.CLIENTS):
            raise errors.TemplatePlacementError(
                f'The "result" attribute of the return type of `next_fn` must be '
                f'placed at CLIENTS, but found {next_fn_result.result}.')
        if (not next_fn_result.result.member.is_struct_with_python()
                or next_fn_result.result.member.python_container
                is not ClientResult):
            raise ClientResultTypeError(
                f'The "result" attribute of the return type of `next_fn` must have '
                f'the `ClientResult` container, but found {next_fn_result.result}.'
            )
        if next_fn_result.measurements.placement != placements.SERVER:
            raise errors.TemplatePlacementError(
                f'The "measurements" attribute of return type of `next_fn` must be '
                f'placed at SERVER, but found {next_fn_result.measurements}.')
예제 #6
0
    def create(
        self, value_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        """Creates a process that clips and zeroes client values before aggregating.

        Each round proceeds as follows:
          - The clipping norm comes from
            `self._clipping_norm_process.report`.
          - The zeroing norm is derived from the clipping norm with
            `self._zeroing_norm_fn`.
          - Both norms are broadcast to clients.
          - Each client clips its value with `tf.clip_by_global_norm`. It
            replaces the value with zeros when the global norm exceeds the
            zeroing norm.
          - The global norm used for these checks is the one returned by
            `tf.clip_by_global_norm`.
          - The resulting values (with client weights) are passed to the inner
            aggregation process.
          - Clipped and zeroed counts are aggregated by their own processes.
          - The clients' global norms feed `self._clipping_norm_process.next`.

        Args:
          value_type: The unplaced type of the values to aggregate; every
            tensor in it must have a floating dtype.

        Returns:
          A `tff.templates.AggregationProcess` whose `next` takes
          `(state, value@CLIENTS, float32 weight@CLIENTS)`.

        Raises:
          TypeError: If some tensor in `value_type` is not of floating dtype
            (`py_typecheck.check_type` also validates `value_type` first).
        """
        py_typecheck.check_type(value_type, factory.ValueType.__args__)

        floating_flags = [
            t.dtype.is_floating for t in structure.flatten(value_type)]
        if not all(floating_flags):
            raise TypeError(
                f'All values in provided value_type must be of floating '
                f'dtype. Provided value_type: {value_type}')

        inner_agg_process = self._inner_agg_factory.create(value_type)

        count_type = computation_types.to_type(COUNT_TF_TYPE)
        clipped_count_agg_process = self._clipped_count_agg_factory.create(
            count_type)
        zeroed_count_agg_process = self._zeroed_count_agg_factory.create(
            count_type)

        @computations.federated_computation()
        def init_fn():
            initial_state = collections.OrderedDict(
                clipping_norm=self._clipping_norm_process.initialize(),
                inner_agg=inner_agg_process.initialize(),
                clipped_count_agg=clipped_count_agg_process.initialize(),
                zeroed_count_agg=zeroed_count_agg_process.initialize())
            return intrinsics.federated_zip(initial_state)

        @computations.tf_computation(value_type, NORM_TF_TYPE, NORM_TF_TYPE)
        def clip_and_zero(value, clipping_norm, zeroing_norm):
            flat_value = tf.nest.flatten(value)
            # `global_norm` is the norm of the *unclipped* value.
            clipped_flat, global_norm = tf.clip_by_global_norm(
                flat_value, clipping_norm)
            clipped_value = tf.nest.pack_sequence_as(value, clipped_flat)
            was_clipped = tf.cast(global_norm > clipping_norm, COUNT_TF_TYPE)
            should_zero = global_norm > zeroing_norm

            def all_zeros():
                return tf.nest.map_structure(tf.zeros_like, value)

            def keep_clipped():
                return clipped_value

            zeroed_and_clipped = tf.cond(should_zero, all_zeros, keep_clipped)
            was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
            return zeroed_and_clipped, global_norm, was_clipped, was_zeroed

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(tf.float32))
        def next_fn(state, value, weight):
            (clipping_norm_state, agg_state, clipped_count_state,
             zeroed_count_state) = state

            clipping_norm = self._clipping_norm_process.report(
                clipping_norm_state)
            zeroing_norm = intrinsics.federated_map(
                self._zeroing_norm_fn, clipping_norm)

            clipping_norm_at_clients = intrinsics.federated_broadcast(
                clipping_norm)
            zeroing_norm_at_clients = intrinsics.federated_broadcast(
                zeroing_norm)
            (zeroed_and_clipped, global_norm, was_clipped,
             was_zeroed) = intrinsics.federated_map(
                 clip_and_zero,
                 (value, clipping_norm_at_clients, zeroing_norm_at_clients))

            new_clipping_norm_state = self._clipping_norm_process.next(
                clipping_norm_state, global_norm)
            agg_output = inner_agg_process.next(
                agg_state, zeroed_and_clipped, weight)
            clipped_count_output = clipped_count_agg_process.next(
                clipped_count_state, was_clipped)
            zeroed_count_output = zeroed_count_agg_process.next(
                zeroed_count_state, was_zeroed)

            new_state = collections.OrderedDict(
                clipping_norm=new_clipping_norm_state,
                inner_agg=agg_output.state,
                clipped_count_agg=clipped_count_output.state,
                zeroed_count_agg=zeroed_count_output.state)
            measurements = collections.OrderedDict(
                agg_process=agg_output.measurements,
                clipping_norm=clipping_norm,
                zeroing_norm=zeroing_norm,
                clipped_count=clipped_count_output.result,
                zeroed_count=zeroed_count_output.result)

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(new_state),
                result=agg_output.result,
                measurements=intrinsics.federated_zip(measurements))

        return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #7
0
    def create(
        self, value_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        """Creates a process that zeroes out client values with a large norm.

        Each round proceeds as follows:
          - The zeroing norm comes from `self._zeroing_norm_process.report`
            and is broadcast to clients.
          - Each client computes the norm of its value. Order 1.0 uses
            `_global_l1_norm`. Order 2.0 uses `tf.linalg.global_norm` over
            the flattened value. `np.inf` uses `_global_inf_norm`.
          - A client replaces its value with zeros when that norm exceeds the
            zeroing norm.
          - The resulting values (with client weights) go to the inner
            aggregation process.
          - The zeroed count is aggregated separately.
          - The client norms feed `self._zeroing_norm_process.next`.

        Args:
          value_type: The unplaced type of the values to aggregate; every
            tensor in it must have a floating dtype.

        Returns:
          A `tff.templates.AggregationProcess` whose `next` takes
          `(state, value@CLIENTS, float32 weight@CLIENTS)`.

        Raises:
          TypeError: If some tensor in `value_type` is not of floating dtype
            (`py_typecheck.check_type` also validates `value_type` first).
          ValueError: If `self._norm_order` is not 1.0, 2.0 or `np.inf`.
        """
        py_typecheck.check_type(value_type, factory.ValueType.__args__)

        # This could perhaps be relaxed if we want to zero out ints for example.
        if not all(
            [t.dtype.is_floating for t in structure.flatten(value_type)]):
            raise TypeError(
                f'All values in provided value_type must be of floating '
                f'dtype. Provided value_type: {value_type}')

        inner_agg_process = self._inner_agg_factory.create(value_type)

        count_type = computation_types.to_type(COUNT_TF_TYPE)
        zeroed_count_agg_process = self._zeroed_count_agg_factory.create(
            count_type)

        @computations.federated_computation()
        def init_fn():
            return intrinsics.federated_zip(
                collections.OrderedDict(
                    zeroing_norm=self._zeroing_norm_process.initialize(),
                    inner_agg=inner_agg_process.initialize(),
                    zeroed_count_agg=zeroed_count_agg_process.initialize()))

        @computations.tf_computation(value_type, NORM_TF_TYPE)
        def zero(value, zeroing_norm):
            if self._norm_order == 1.0:
                norm = _global_l1_norm(value)
            elif self._norm_order == 2.0:
                norm = tf.linalg.global_norm(tf.nest.flatten(value))
            elif self._norm_order == np.inf:
                # Compare by value, not identity: `float('inf')` and
                # `math.inf` are equal to, but not the same object as, `np.inf`.
                norm = _global_inf_norm(value)
            else:
                # Explicit raise rather than `assert`, which `python -O` strips
                # (leaving `norm` unbound).
                raise ValueError(
                    f'Unsupported norm_order {self._norm_order!r}; expected '
                    f'1.0, 2.0 or np.inf.')
            should_zero = (norm > zeroing_norm)
            zeroed_value = tf.cond(
                should_zero,
                lambda: tf.nest.map_structure(tf.zeros_like, value),
                lambda: value)
            was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
            return zeroed_value, norm, was_zeroed

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(tf.float32))
        def next_fn(state, value, weight):
            zeroing_norm_state, agg_state, zeroed_count_state = state

            zeroing_norm = self._zeroing_norm_process.report(
                zeroing_norm_state)

            zeroed_value, norm, was_zeroed = intrinsics.federated_map(
                zero, (value, intrinsics.federated_broadcast(zeroing_norm)))

            new_zeroing_norm_state = self._zeroing_norm_process.next(
                zeroing_norm_state, norm)
            agg_output = inner_agg_process.next(agg_state, zeroed_value,
                                                weight)
            zeroed_count_output = zeroed_count_agg_process.next(
                zeroed_count_state, was_zeroed)

            new_state = collections.OrderedDict(
                zeroing_norm=new_zeroing_norm_state,
                inner_agg=agg_output.state,
                zeroed_count_agg=zeroed_count_output.state)
            measurements = collections.OrderedDict(
                agg_process=agg_output.measurements,
                zeroing_norm=zeroing_norm,
                zeroed_count=zeroed_count_output.result)

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(new_state),
                result=agg_output.result,
                measurements=intrinsics.federated_zip(measurements))

        return aggregation_process.AggregationProcess(init_fn, next_fn)