def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Creates an `AggregationProcess` that applies a randomized Hadamard rotation.

  On clients, each tensor of `value_type` is flattened and zero-padded via
  `_flatten_and_pad_zeros_pow2`. Then, `self._num_repeats` times, it is
  multiplied elementwise by a seeded `sample_rademacher` sample and passed
  through `hadamard.fast_walsh_hadamard_transform`.

  The rotated values are aggregated by the process built from
  `self._inner_agg_factory`. On the server, the same operations are applied
  in reverse order with the seeds walked backwards, which is intended to
  invert the client rotation. Each tensor is then sliced and reshaped back to
  the tensor specs of `value_type`.

  Args:
    value_type: The non-federated type of the values to aggregate. Validated
      by `_check_value_type`.

  Returns:
    A `tff.templates.AggregationProcess`. Its state is the zipped pair of the
    inner process state and the value produced by `_init_global_seed`.
  """
  _check_value_type(value_type)
  value_specs = type_conversions.type_to_tf_tensor_specs(value_type)
  # Every tensor in the structure consumes `_num_repeats` seeds per round.
  # This stride is used to advance the global seed between rounds.
  seeds_per_round = self._num_repeats * len(structure.flatten(value_type))
  next_global_seed_fn = _build_next_global_seed_fn(stride=seeds_per_round)

  @tensorflow_computation.tf_computation(value_type, SEED_TFF_TYPE)
  def client_transform(value, global_seed):
    """Applies the forward randomized Hadamard rotation to every tensor."""

    @tf.function
    def transform(tensor, seed):
      for _ in range(self._num_repeats):
        # Random signs, then the Hadamard transform. The transform is applied
        # to the tensor with an added leading axis, which is removed again.
        tensor *= sample_rademacher(tf.shape(tensor), tensor.dtype, seed)
        tensor = tf.expand_dims(tensor, axis=0)
        tensor = hadamard.fast_walsh_hadamard_transform(tensor)
        tensor = tf.squeeze(tensor, axis=0)
        # Use a different seed for every repeat.
        seed += 1
      return tensor

    value = _flatten_and_pad_zeros_pow2(value)
    # One starting seed per tensor, spaced `_num_repeats` apart so that the
    # per-repeat increments above do not collide across tensors.
    seeds = _unique_seeds_for_struct(
        value, global_seed, stride=self._num_repeats)
    return tf.nest.map_structure(transform, value, seeds)

  inner_agg_process = self._inner_agg_factory.create(
      client_transform.type_signature.result)

  @tensorflow_computation.tf_computation(
      client_transform.type_signature.result, SEED_TFF_TYPE)
  def server_transform(value, global_seed):
    """Reverses the client rotation and restores the original tensor specs."""

    @tf.function
    def transform(tensor, seed):
      # Start from the last seed used on clients and walk backwards, so the
      # repeats are processed in the reverse of the client order.
      seed += self._num_repeats - 1
      for _ in range(self._num_repeats):
        # NOTE(review): re-applying the transform and the same-seed sign
        # sample is assumed to invert the client step. This presumes the
        # transform is self-inverse (normalized) and the Rademacher samples
        # are +/-1 — confirm in `hadamard` / `sample_rademacher`.
        tensor = tf.expand_dims(tensor, axis=0)
        tensor = hadamard.fast_walsh_hadamard_transform(tensor)
        tensor = tf.squeeze(tensor, axis=0)
        tensor *= sample_rademacher(tf.shape(tensor), tensor.dtype, seed)
        seed -= 1
      return tensor

    seeds = _unique_seeds_for_struct(
        value, global_seed, stride=self._num_repeats)
    value = tf.nest.map_structure(transform, value, seeds)
    # Drop the padding and restore each tensor's original shape.
    return tf.nest.map_structure(_slice_and_reshape_to_template_spec, value,
                                 value_specs)

  @federated_computation.federated_computation()
  def init_fn():
    inner_state = inner_agg_process.initialize()
    my_state = intrinsics.federated_eval(
        tensorflow_computation.tf_computation(_init_global_seed),
        placements.SERVER)
    return intrinsics.federated_zip((inner_state, my_state))

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    # The shared round logic lives in `_build_next_fn`. 'hd' is presumably a
    # name/tag for this rotation kind — confirm in `_build_next_fn`.
    next_fn_impl = _build_next_fn(client_transform, inner_agg_process,
                                  server_transform, next_global_seed_fn, 'hd')
    return next_fn_impl(state, value)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def __init__(self, initialize_fn, next_fn):
  """Constructs a `FinalizerProcess` and validates its type signatures.

  The base class constructor is called with `next_is_multi_arg=True`. The
  following properties are then checked:

  * `initialize_fn` returns a federated type placed at SERVER.
  * Every leaf of the `next_fn` parameter and result types is federated.
  * `next_fn` takes exactly three arguments, and the second and third are
    placed at SERVER.
  * The "result" attribute of the `next_fn` result is placed at SERVER.
  * The member of that "result" attribute is assignable to the member of the
    second argument.
  * The "measurements" attribute of the `next_fn` result is placed at SERVER.

  Args:
    initialize_fn: The computation producing the initial state.
    next_fn: The computation advancing the process.

  Raises:
    errors.TemplateNotFederatedError: If `initialize_fn` does not return a
      federated type, or a leaf of the `next_fn` signature is not federated.
    errors.TemplatePlacementError: If one of the checked types is not a
      federated type placed at SERVER.
    errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly three
      input arguments.
    FinalizerResultTypeError: If the "result" member is not assignable to the
      member of the second `next_fn` argument.
  """
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  def is_at_server(type_spec):
    # Check `is_federated()` before reading `placement`. This way a
    # non-federated type (e.g. a struct of federated values) is reported as a
    # placement error instead of failing on the attribute access.
    return (type_spec.is_federated() and
            type_spec.placement == placements.SERVER)

  if not initialize_fn.type_signature.result.is_federated():
    raise errors.TemplateNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  non_federated_types = [t for t in next_types if not t.is_federated()]
  if non_federated_types:
    # `str.join` only accepts strings, so format each type explicitly.
    offending_types = '\n- '.join(str(t) for t in non_federated_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n {offending_types}.')

  if initialize_fn.type_signature.result.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        f'The state controlled by an `FinalizerProcess` must be placed at '
        f'the SERVER, but found type: {initialize_fn.type_signature.result}.'
    )
  # Note that state of next_fn being placed at SERVER is now ensured by the
  # assertions in base class which would otherwise raise
  # TemplateStateNotAssignableError.

  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.'
    )
  if len(next_fn_param) != 3:
    next_param_str = '\n- '.join([str(t) for t in next_fn_param])
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n{next_param_str}')

  model_weights_param = next_fn_param[1]
  update_from_clients_param = next_fn_param[2]
  if not is_at_server(model_weights_param):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at SERVER '
        f'but found {model_weights_param}.')
  if not is_at_server(update_from_clients_param):
    raise errors.TemplatePlacementError(
        f'The third input argument of `next_fn` must be placed at SERVER '
        f'but found {update_from_clients_param}.')

  next_fn_result = next_fn.type_signature.result
  if not is_at_server(next_fn_result.result):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of the return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.result}.')
  # The finalized output must be usable where the second argument is.
  if not model_weights_param.member.is_assignable_from(
      next_fn_result.result.member):
    raise FinalizerResultTypeError(
        f'The second input argument of `next_fn` must match the "result" '
        f'attribute of the return type of `next_fn`. Found:\n'
        f'Second input argument: {next_fn_param[1].member}\n'
        f'Result attribute: {next_fn_result.result.member}.')
  if not is_at_server(next_fn_result.measurements):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.measurements}.')
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Creates an `AggregationProcess` that applies a randomized DFT rotation.

  On clients, each tensor of `value_type` is flattened and zero-padded via
  `_flatten_and_pad_zeros_even`. Then, `self._num_repeats` times:

  * its two halves are taken as the real and imaginary parts of a complex
    vector;
  * the vector is multiplied elementwise by a seeded `sample_cis` sample;
  * it is passed through `tf.signal.fft`;
  * it is written back as concatenated real and imaginary parts;
  * it is divided by the square root of the number of complex entries.

  The rotated values are aggregated by the process built from
  `self._inner_agg_factory`. On the server, the steps are reversed
  (`tf.signal.ifft`, `sample_cis(..., inverse=True)`, seeds walked
  backwards). Each tensor is then sliced and reshaped back to the tensor
  specs of `value_type`.

  Args:
    value_type: The non-federated type of the values to aggregate. Validated
      by `_check_value_type`.

  Returns:
    A `tff.templates.AggregationProcess`. Its state is the zipped pair of the
    inner process state and the value produced by `_init_global_seed`.
  """
  _check_value_type(value_type)
  value_specs = type_conversions.structure_from_tensor_type_tree(
      lambda x: tf.TensorSpec(x.shape, x.dtype), value_type)
  # Every tensor in the structure consumes `_num_repeats` seeds per round.
  # This stride is used to advance the global seed between rounds.
  seeds_per_round = self._num_repeats * len(structure.flatten(value_type))
  next_global_seed_fn = _build_next_global_seed_fn(stride=seeds_per_round)

  @tensorflow_computation.tf_computation(value_type, SEED_TFF_TYPE)
  def client_transform(value, global_seed):
    """Applies the forward randomized DFT rotation to every tensor."""

    @tf.function
    def transform(tensor, seed):
      for _ in range(self._num_repeats):
        # View the first/second half as real/imaginary parts. This relies on
        # an even length, presumably guaranteed by
        # `_flatten_and_pad_zeros_even`.
        tensor = tf.reshape(tensor, [2, -1])
        tensor = tf.complex(real=tensor[0], imag=tensor[1])
        tensor *= sample_cis(tf.shape(tensor), seed, inverse=False)
        tensor = tf.signal.fft(tensor)
        tensor = tf.concat(
            [tf.math.real(tensor), tf.math.imag(tensor)], axis=0)
        # After the concat, `tf.size` counts real and imaginary parts, so
        # `size / 2` is the number of complex entries. Dividing by its square
        # root presumably makes the unnormalized FFT norm-preserving.
        tensor /= tf.cast(tf.sqrt(tf.size(tensor) / 2), OUTPUT_TF_DTYPE)
        # Use a different seed for every repeat.
        seed += 1
      return tensor

    value = _flatten_and_pad_zeros_even(value)
    # One starting seed per tensor, spaced `_num_repeats` apart.
    seeds = _unique_seeds_for_struct(
        value, global_seed, stride=self._num_repeats)
    return tf.nest.map_structure(transform, value, seeds)

  inner_agg_process = self._inner_agg_factory.create(
      client_transform.type_signature.result)

  @tensorflow_computation.tf_computation(
      client_transform.type_signature.result, SEED_TFF_TYPE)
  def server_transform(value, global_seed):
    """Reverses the client rotation and restores the original tensor specs."""

    @tf.function
    def transform(tensor, seed):
      # Start from the last seed used on clients and walk backwards.
      seed += self._num_repeats - 1
      for _ in range(self._num_repeats):
        # Multiply back the factor the client divided by.
        # NOTE(review): `tf.size` with a floating `out_type` — confirm this is
        # supported when tensor shapes are not fully static.
        tensor *= tf.sqrt(tf.size(tensor, out_type=tensor.dtype) / 2.0)
        tensor = tf.reshape(tensor, [2, -1])
        tensor = tf.complex(real=tensor[0], imag=tensor[1])
        tensor = tf.signal.ifft(tensor)
        tensor *= sample_cis(tf.shape(tensor), seed, inverse=True)
        tensor = tf.concat(
            [tf.math.real(tensor), tf.math.imag(tensor)], axis=0)
        seed -= 1
      return tensor

    seeds = _unique_seeds_for_struct(
        value, global_seed, stride=self._num_repeats)
    value = tf.nest.map_structure(transform, value, seeds)
    # Drop the padding and restore each tensor's original shape.
    return tf.nest.map_structure(_slice_and_reshape_to_template_spec, value,
                                 value_specs)

  @federated_computation.federated_computation()
  def init_fn():
    inner_state = inner_agg_process.initialize()
    my_state = intrinsics.federated_eval(
        tensorflow_computation.tf_computation(_init_global_seed),
        placements.SERVER)
    return intrinsics.federated_zip((inner_state, my_state))

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    # The shared round logic lives in `_build_next_fn`. 'dft' is presumably a
    # name/tag for this rotation kind — confirm in `_build_next_fn`.
    next_fn_impl = _build_next_fn(client_transform, inner_agg_process,
                                  server_transform, next_global_seed_fn, 'dft')
    return next_fn_impl(state, value)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Builds a weighted-mean `tff.aggregators.AggregationProcess`.

  `value_type` must be an unplaced `tff.TensorType` or `tff.StructType`
  (`value_type.is_federated()` is `False`), and every tensor in it must have a
  floating dtype. The returned process's `next` method has the signature
  `<S@SERVER, {value_type}@CLIENTS, {float32}@CLIENTS>`, where `S` is the
  unplaced type returned by `initialize`.

  Each client value is multiplied by its weight. The weighted values and the
  weights are summed by the processes built from the configured value-sum and
  weight-sum factories, and the server divides the two sums.

  Args:
    value_type: A `tff.Type` without placement.

  Returns:
    A `tff.templates.AggregationProcess`.

  Raises:
    TypeError: If some tensor in `value_type` is not of floating dtype.
  """
  py_typecheck.check_type(value_type, factory.ValueType.__args__)
  floating_flags = [t.dtype.is_floating for t in structure.flatten(value_type)]
  if not all(floating_flags):
    raise TypeError(f'All values in provided value_type must be of floating '
                    f'dtype. Provided value_type: {value_type}')

  weight_type = computation_types.to_type(tf.float32)
  values_summer = self._value_sum_factory.create(value_type)
  weights_summer = self._weight_sum_factory.create(weight_type)

  @computations.federated_computation()
  def init_fn():
    initial_state = collections.OrderedDict()
    initial_state['value_sum_process'] = values_summer.initialize()
    initial_state['weight_sum_process'] = weights_summer.initialize()
    return intrinsics.federated_zip(initial_state)

  @computations.federated_computation(
      init_fn.type_signature.result,
      computation_types.FederatedType(value_type, placements.CLIENTS),
      computation_types.FederatedType(weight_type, placements.CLIENTS))
  def next_fn(state, value, weight):
    # On clients: scale each value by its weight.
    scaled_value = intrinsics.federated_map(_mul, (value, weight))

    # Inner sums of the scaled values and of the raw weights.
    summed_values = values_summer.next(state['value_sum_process'],
                                       scaled_value)
    summed_weights = weights_summer.next(state['weight_sum_process'], weight)

    # On the server: divide, optionally guarding against a zero total weight.
    divide = _div_no_nan if self._no_nan_division else _div
    mean = intrinsics.federated_map(
        divide, (summed_values.result, summed_weights.result))

    updated_state = collections.OrderedDict(
        value_sum_process=summed_values.state,
        weight_sum_process=summed_weights.state)
    reported = collections.OrderedDict(
        value_sum_process=summed_values.measurements,
        weight_sum_process=summed_weights.measurements)
    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(updated_state),
        result=mean,
        measurements=intrinsics.federated_zip(reported))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def __init__(self, initialize_fn, next_fn):
  """Constructs a `ClientWorkProcess` and validates its type signatures.

  The base class constructor is called with `next_is_multi_arg=True`. The
  following properties are then checked:

  * `initialize_fn` returns a federated type placed at SERVER.
  * Every leaf of the `next_fn` parameter and result types is federated.
  * `next_fn` takes exactly three arguments, and the second and third are
    placed at CLIENTS.
  * The third argument's member is a sequence of a TensorFlow-compatible
    element type, or a (possibly nested) struct of such sequences.
  * The "result" attribute of the `next_fn` result is placed at CLIENTS and
    uses the `ClientResult` container.
  * The "measurements" attribute of the `next_fn` result is placed at SERVER.

  Args:
    initialize_fn: The computation producing the initial state.
    next_fn: The computation performing the client work.

  Raises:
    errors.TemplateNotFederatedError: If `initialize_fn` does not return a
      federated type, or a leaf of the `next_fn` signature is not federated.
    errors.TemplatePlacementError: If one of the checked types is not a
      federated type with the required placement.
    errors.TemplateNextFnNumArgsError: If `next_fn` does not take exactly three
      input arguments.
    ClientDataTypeError: If the client data argument has a disallowed type.
    ClientResultTypeError: If the "result" member does not use the
      `ClientResult` container.
  """
  super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

  def is_placed_at(type_spec, placement):
    # Check `is_federated()` before reading `placement`. This way a
    # non-federated type (e.g. a struct of federated values) is reported as a
    # placement error instead of failing on the attribute access.
    return type_spec.is_federated() and type_spec.placement == placement

  if not initialize_fn.type_signature.result.is_federated():
    raise errors.TemplateNotFederatedError(
        f'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
        f'see a collection of federated types, try wrapping the returned '
        f'value in `tff.federated_zip` before returning.')

  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  non_federated_types = [t for t in next_types if not t.is_federated()]
  if non_federated_types:
    # `str.join` only accepts strings, so format each type explicitly.
    offending_types = '\n- '.join(str(t) for t in non_federated_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n {offending_types}.')

  if initialize_fn.type_signature.result.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        f'The state controlled by a `ClientWorkProcess` must be placed at '
        f'the SERVER, but found type: {initialize_fn.type_signature.result}.'
    )
  # Note that state of next_fn being placed at SERVER is now ensured by the
  # assertions in base class which would otherwise raise
  # TemplateStateNotAssignableError.

  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.'
    )
  if len(next_fn_param) != 3:
    next_param_str = '\n- '.join([str(t) for t in next_fn_param])
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n{next_param_str}')

  second_next_param = next_fn_param[1]
  client_data_param = next_fn_param[2]
  if not is_placed_at(second_next_param, placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at CLIENTS '
        f'but found {second_next_param}.')
  if not is_placed_at(client_data_param, placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The third input argument of `next_fn` must be placed at CLIENTS '
        f'but found {client_data_param}.')

  def is_allowed_client_data_type(
      type_spec: computation_types.Type) -> bool:
    """True for a sequence of TF-compatible elements, or a struct of them."""
    if type_spec.is_sequence():
      return type_analysis.is_tensorflow_compatible_type(type_spec.element)
    elif type_spec.is_struct():
      return all(
          is_allowed_client_data_type(element_type)
          for element_type in type_spec.children())
    else:
      return False

  if not is_allowed_client_data_type(client_data_param.member):
    raise ClientDataTypeError(
        f'The third input argument of `next_fn` must be a sequence or '
        f'a structure of sequences, but found {client_data_param}.')

  next_fn_result = next_fn.type_signature.result
  if not is_placed_at(next_fn_result.result, placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of the return type of `next_fn` must be '
        f'placed at CLIENTS, but found {next_fn_result.result}.')
  if (not next_fn_result.result.member.is_struct_with_python() or
      next_fn_result.result.member.python_container is not ClientResult):
    raise ClientResultTypeError(
        f'The "result" attribute of the return type of `next_fn` must have '
        f'the `ClientResult` container, but found {next_fn_result.result}.'
    )
  if not is_placed_at(next_fn_result.measurements, placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.measurements}.')
def create(
    self, value_type: factory.ValueType
) -> aggregation_process.AggregationProcess:
  """Builds a process that clips and zeroes client values before aggregating.

  Per round, the clipping norm is read from `self._clipping_norm_process`,
  and the zeroing norm is derived from it via `self._zeroing_norm_fn`.

  On clients, each value is clipped to the clipping norm by global norm. If
  its global norm exceeds the zeroing norm, it is replaced by zeros instead.

  The results are aggregated by the inner process. The global norms update
  the clipping-norm process. Per-client "was clipped" / "was zeroed"
  indicators are aggregated by the count processes and reported as
  measurements.

  Args:
    value_type: A non-federated `tff.Type` whose tensors are all floating.

  Returns:
    A `tff.templates.AggregationProcess`.

  Raises:
    TypeError: If some tensor in `value_type` is not of floating dtype.
  """
  py_typecheck.check_type(value_type, factory.ValueType.__args__)
  floating_flags = [t.dtype.is_floating for t in structure.flatten(value_type)]
  if not all(floating_flags):
    raise TypeError(
        f'All values in provided value_type must be of floating '
        f'dtype. Provided value_type: {value_type}')

  inner_agg_process = self._inner_agg_factory.create(value_type)
  count_type = computation_types.to_type(COUNT_TF_TYPE)
  clipped_counter = self._clipped_count_agg_factory.create(count_type)
  zeroed_counter = self._zeroed_count_agg_factory.create(count_type)

  @computations.federated_computation()
  def init_fn():
    initial_state = collections.OrderedDict(
        clipping_norm=self._clipping_norm_process.initialize(),
        inner_agg=inner_agg_process.initialize(),
        clipped_count_agg=clipped_counter.initialize(),
        zeroed_count_agg=zeroed_counter.initialize())
    return intrinsics.federated_zip(initial_state)

  @computations.tf_computation(value_type, NORM_TF_TYPE, NORM_TF_TYPE)
  def clip_and_zero(value, clipping_norm, zeroing_norm):
    leaves = tf.nest.flatten(value)
    clipped_leaves, norm = tf.clip_by_global_norm(leaves, clipping_norm)
    clipped = tf.nest.pack_sequence_as(value, clipped_leaves)
    clipped_indicator = tf.cast(norm > clipping_norm, COUNT_TF_TYPE)
    exceeds_zeroing_norm = norm > zeroing_norm

    def all_zeros():
      return tf.nest.map_structure(tf.zeros_like, value)

    def keep_clipped():
      return clipped

    output = tf.cond(exceeds_zeroing_norm, all_zeros, keep_clipped)
    zeroed_indicator = tf.cast(exceeds_zeroing_norm, COUNT_TF_TYPE)
    return output, norm, clipped_indicator, zeroed_indicator

  @computations.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(value_type),
      computation_types.at_clients(tf.float32))
  def next_fn(state, value, weight):
    (norm_state, inner_state, clipped_count_state,
     zeroed_count_state) = state

    clip_norm = self._clipping_norm_process.report(norm_state)
    zero_norm = intrinsics.federated_map(self._zeroing_norm_fn, clip_norm)

    clip_norm_at_clients = intrinsics.federated_broadcast(clip_norm)
    zero_norm_at_clients = intrinsics.federated_broadcast(zero_norm)
    (processed_value, client_norm, clipped_flag,
     zeroed_flag) = intrinsics.federated_map(
         clip_and_zero, (value, clip_norm_at_clients, zero_norm_at_clients))

    next_norm_state = self._clipping_norm_process.next(norm_state, client_norm)
    inner_output = inner_agg_process.next(inner_state, processed_value, weight)
    clipped_output = clipped_counter.next(clipped_count_state, clipped_flag)
    zeroed_output = zeroed_counter.next(zeroed_count_state, zeroed_flag)

    next_state = collections.OrderedDict(
        clipping_norm=next_norm_state,
        inner_agg=inner_output.state,
        clipped_count_agg=clipped_output.state,
        zeroed_count_agg=zeroed_output.state)
    reported = collections.OrderedDict(
        agg_process=inner_output.measurements,
        clipping_norm=clip_norm,
        zeroing_norm=zero_norm,
        clipped_count=clipped_output.result,
        zeroed_count=zeroed_output.result)
    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(next_state),
        result=inner_output.result,
        measurements=intrinsics.federated_zip(reported))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create(
    self, value_type: factory.ValueType
) -> aggregation_process.AggregationProcess:
  """Builds a process that zeroes large-norm client values before aggregating.

  Per round, the zeroing norm is read from `self._zeroing_norm_process`.

  On clients, the norm of each value is computed according to
  `self._norm_order`:

  * 1.0 uses `_global_l1_norm`;
  * 2.0 uses the global L2 norm;
  * infinity uses `_global_inf_norm`.

  Values whose norm exceeds the zeroing norm are replaced by zeros. The
  remaining values are aggregated by the inner process. The norms update the
  zeroing-norm process, and a per-client "was zeroed" indicator is aggregated
  by the zeroed-count process and reported as a measurement.

  Args:
    value_type: A non-federated `tff.Type` whose tensors are all floating.

  Returns:
    A `tff.templates.AggregationProcess`.

  Raises:
    TypeError: If some tensor in `value_type` is not of floating dtype.
    ValueError: If `self._norm_order` is not 1.0, 2.0 or positive infinity.
  """
  py_typecheck.check_type(value_type, factory.ValueType.__args__)
  # This could perhaps be relaxed if we want to zero out ints for example.
  if not all(t.dtype.is_floating for t in structure.flatten(value_type)):
    raise TypeError(
        f'All values in provided value_type must be of floating '
        f'dtype. Provided value_type: {value_type}')

  def _global_l2_norm(value):
    return tf.linalg.global_norm(tf.nest.flatten(value))

  # Resolve the norm up front. Infinity is compared by value (`==`) rather
  # than identity (`is np.inf`), so any float infinity is accepted.
  # Unsupported orders raise a real exception instead of relying on an
  # `assert` during tracing, which is stripped under `python -O`.
  if self._norm_order == 1.0:
    norm_fn = _global_l1_norm
  elif self._norm_order == 2.0:
    norm_fn = _global_l2_norm
  elif self._norm_order == np.inf:
    norm_fn = _global_inf_norm
  else:
    raise ValueError(
        f'Unsupported norm_order {self._norm_order!r}; expected 1.0, 2.0 or '
        f'inf.')

  inner_agg_process = self._inner_agg_factory.create(value_type)
  count_type = computation_types.to_type(COUNT_TF_TYPE)
  zeroed_count_agg_process = self._zeroed_count_agg_factory.create(
      count_type)

  @computations.federated_computation()
  def init_fn():
    return intrinsics.federated_zip(
        collections.OrderedDict(
            zeroing_norm=self._zeroing_norm_process.initialize(),
            inner_agg=inner_agg_process.initialize(),
            zeroed_count_agg=zeroed_count_agg_process.initialize()))

  @computations.tf_computation(value_type, NORM_TF_TYPE)
  def zero(value, zeroing_norm):
    """Zeroes `value` if its norm exceeds `zeroing_norm`; also returns both."""
    norm = norm_fn(value)
    should_zero = (norm > zeroing_norm)
    zeroed_value = tf.cond(
        should_zero, lambda: tf.nest.map_structure(tf.zeros_like, value),
        lambda: value)
    was_zeroed = tf.cast(should_zero, COUNT_TF_TYPE)
    return zeroed_value, norm, was_zeroed

  @computations.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(value_type),
      computation_types.at_clients(tf.float32))
  def next_fn(state, value, weight):
    zeroing_norm_state, agg_state, zeroed_count_state = state

    zeroing_norm = self._zeroing_norm_process.report(zeroing_norm_state)

    zeroed_value, norm, was_zeroed = intrinsics.federated_map(
        zero, (value, intrinsics.federated_broadcast(zeroing_norm)))

    new_zeroing_norm_state = self._zeroing_norm_process.next(
        zeroing_norm_state, norm)

    agg_output = inner_agg_process.next(agg_state, zeroed_value, weight)

    zeroed_count_output = zeroed_count_agg_process.next(
        zeroed_count_state, was_zeroed)

    new_state = collections.OrderedDict(
        zeroing_norm=new_zeroing_norm_state,
        inner_agg=agg_output.state,
        zeroed_count_agg=zeroed_count_output.state)
    measurements = collections.OrderedDict(
        agg_process=agg_output.measurements,
        zeroing_norm=zeroing_norm,
        zeroed_count=zeroed_count_output.result)

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=agg_output.result,
        measurements=intrinsics.federated_zip(measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)