def _check_value_type(value_type):
  """Validates that `value_type` is an accepted type made only of floats.

  Args:
    value_type: The type to validate. It is first passed to
      `py_typecheck.check_type` together with the member types of
      `factory.ValueType`.

  Raises:
    TypeError: If `type_analysis.is_structure_of_floats` reports that
      `value_type` is not a structure of floats.
  """
  accepted_types = typing.get_args(factory.ValueType)
  py_typecheck.check_type(value_type, accepted_types)

  # Guard clause: nothing more to do for an all-float structure.
  if type_analysis.is_structure_of_floats(value_type):
    return

  raise TypeError(
      'All values in provided value_type must be of floating '
      f'dtype. Provided value_type: {value_type}')
def _check_value_type(value_type):
  """Validates the kind and component dtypes of `value_type`.

  Args:
    value_type: The type to validate.

  Raises:
    TypeError: If `value_type` is neither a tensor type nor a
      struct-with-python type that is a structure of tensors, or if it is
      neither a structure of floats nor a structure of integers.
  """
  # Same short-circuit order as a single `or`/`and` expression: the struct
  # checks are only evaluated when `value_type` is not a tensor.
  has_supported_kind = value_type.is_tensor() or (
      value_type.is_struct_with_python()
      and type_analysis.is_structure_of_tensors(value_type))
  if not has_supported_kind:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {value_type!r}')

  has_supported_dtypes = (
      type_analysis.is_structure_of_floats(value_type)
      or type_analysis.is_structure_of_integers(value_type))
  if has_supported_dtypes:
    return

  raise TypeError(
      'Component dtypes of `value_type` must be all integers or '
      f'all floats. Found {value_type}.')
def _check_value_type_compatible_with_config_mode(self, value_type):
  """Checks that `value_type` matches the dtype family of `self._config_mode`.

  Args:
    value_type: The type to validate. It is first passed to
      `py_typecheck.check_type` together with the member types of
      `factory.ValueType`.

  Raises:
    TypeError: If `self._config_mode` is `_Config.INT` and `value_type` is
      not a structure of integers, or if it is `_Config.FLOAT` and
      `value_type` is not a structure of floats.
    ValueError: If `self._config_mode` is neither `_Config.INT` nor
      `_Config.FLOAT`.
  """
  # Imported locally so this method does not depend on a module-level
  # `import typing` being present.
  import typing  # pylint: disable=g-import-not-at-top

  # `typing.get_args` is the public accessor for the members of a generic
  # alias; it replaces reading the private `__args__` attribute directly.
  # NOTE(review): assumes `factory.ValueType` is a `typing.Union` - confirm.
  type_args = typing.get_args(factory.ValueType)
  py_typecheck.check_type(value_type, type_args)

  if self._config_mode == _Config.INT:
    if not type_analysis.is_structure_of_integers(value_type):
      raise TypeError(
          'The `SecureSumFactory` was configured to work with integer '
          'dtypes. All values in provided `value_type` hence must be of '
          f'integer dtype. \nProvided value_type: {value_type}')
  elif self._config_mode == _Config.FLOAT:
    if not type_analysis.is_structure_of_floats(value_type):
      raise TypeError(
          'The `SecureSumFactory` was configured to work with floating '
          'point dtypes. All values in provided `value_type` hence must be '
          f'of floating point dtype. \nProvided value_type: {value_type}')
  else:
    # Reaching here means the factory holds a mode this method does not know.
    raise ValueError(f'Unexpected internal config type: {self._config_mode}')
def create(self, value_type):
  """Builds the aggregation process for `value_type`.

  As a side effect, stores the total number of elements of `value_type` on
  `self._client_dim` before the inner aggregation factory is built.

  Args:
    value_type: A tensor type, or a struct-with-python type that is a
      structure of tensors, whose component dtypes are integers or floats.

  Returns:
    An `aggregation_process.AggregationProcess` whose `initialize` comes from
    the inner process and whose `next` re-derives measurements and, when
    `self._auto_l2_clip` is set, replaces the state with the result of
    `self._autotune_component_states`.

  Raises:
    TypeError: If `value_type` has an unsupported kind or its component
      dtypes are neither a structure of floats nor a structure of integers.
  """
  # Validate the kind of `value_type` while computing the client dimension.
  if (value_type.is_struct_with_python() and
      type_analysis.is_structure_of_tensors(value_type)):
    per_tensor_counts = type_conversions.structure_from_tensor_type_tree(
        lambda tensor_type: tensor_type.shape.num_elements(), value_type)
    self._client_dim = sum(tf.nest.flatten(per_tensor_counts))
  elif value_type.is_tensor():
    self._client_dim = value_type.shape.num_elements()
  else:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {value_type!r}')

  # Every component must be an integer or a float.
  dtypes_supported = (
      type_analysis.is_structure_of_floats(value_type)
      or type_analysis.is_structure_of_integers(value_type))
  if not dtypes_supported:
    raise TypeError(
        'Component dtypes of `value_type` must all be integers '
        f'or floats. Found {value_type!r}.')

  inner_process = self._build_aggregation_factory().create(value_type)
  init_fn = inner_process.initialize

  @federated_computation.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(value_type))
  def next_fn(state, value):
    inner_output = inner_process.next(state, value)
    derived_measurements = self._derive_measurements(
        inner_output.state, inner_output.measurements)
    if self._auto_l2_clip:
      updated_state = self._autotune_component_states(inner_output.state)
    else:
      updated_state = inner_output.state
    return measured_process.MeasuredProcessOutput(
        state=updated_state,
        result=inner_output.result,
        measurements=derived_measurements)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def _check_value_type_compatible_with_config_mode(self, value_type):
  """Checks that `value_type` matches the dtype family of `self._config_mode`.

  Args:
    value_type: The type to validate. It is first passed to
      `py_typecheck.check_type` together with the member types of
      `factory.ValueType`.

  Raises:
    TypeError: If `_is_structure_of_single_dtype` rejects `value_type`, or if
      `value_type` is not a structure of integers (mode `_Config.INT`) or
      not a structure of floats (mode `_Config.FLOAT`).
    ValueError: If `self._config_mode` is neither `_Config.INT` nor
      `_Config.FLOAT`.
  """
  accepted_types = typing.get_args(factory.ValueType)
  py_typecheck.check_type(value_type, accepted_types)

  if not _is_structure_of_single_dtype(value_type):
    raise TypeError(
        'Expected a type which is a structure containing the same dtypes, '
        f'found {value_type}.')

  if self._config_mode == _Config.INT:
    if type_analysis.is_structure_of_integers(value_type):
      return
    raise TypeError(
        'The `SecureSumFactory` was configured to work with integer '
        'dtypes. All values in provided `value_type` hence must be of '
        f'integer dtype. \nProvided value_type: {value_type}')

  if self._config_mode == _Config.FLOAT:
    if type_analysis.is_structure_of_floats(value_type):
      return
    raise TypeError(
        'The `SecureSumFactory` was configured to work with floating '
        'point dtypes. All values in provided `value_type` hence must be '
        f'of floating point dtype. \nProvided value_type: {value_type}')

  # Neither known mode matched.
  raise ValueError(f'Unexpected internal config type: {self._config_mode}')
def create(self, value_type):
  """Builds a process that discretizes, aggregates, then undiscretizes.

  Args:
    value_type: A tensor type, or a struct-with-python type that is a
      structure of tensors, whose component dtypes are all floats. A struct
      is rejected when `self._prior_norm_bound` is truthy.

  Returns:
    An `aggregation_process.AggregationProcess`. Its state is an ordered
    struct with keys `scale_factor`, `prior_norm_bound` and
    `inner_agg_process`; its measurements carry the inner process's
    measurements under the key `discretize`.

  Raises:
    TypeError: If `value_type` has an unsupported kind, is a struct while
      `self._prior_norm_bound` is truthy, or is not a structure of floats.
  """
  # Validate `value_type` and remember its TF dtype(s) for undiscretization.
  if value_type.is_tensor():
    original_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    if self._prior_norm_bound:
      raise TypeError(
          'If `prior_norm_bound` is specified, `value_type` must '
          f'be `TensorType`. Found type: {value_type!r}.')
    original_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda tensor_type: tensor_type.dtype, value_type)
  else:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {value_type!r}')

  # Only floating point components are accepted.
  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError(
        'Component dtypes of `value_type` must all be floats. '
        f'Found {value_type!r}.')

  discretize_fn = _build_discretize_fn(value_type, self._stochastic,
                                       self._beta)
  discretized_type = discretize_fn.type_signature.result

  @tensorflow_computation.tf_computation(discretized_type, tf.float32)
  def undiscretize_fn(value, scale_factor):
    return _undiscretize_struct(value, scale_factor, original_dtype)

  inner_process = self._inner_agg_factory.create(discretized_type)

  @federated_computation.federated_computation()
  def init_fn():
    # Pairs are listed in the order the state struct exposes its fields.
    initial_state = collections.OrderedDict([
        ('scale_factor',
         intrinsics.federated_value(self._scale_factor, placements.SERVER)),
        ('prior_norm_bound',
         intrinsics.federated_value(self._prior_norm_bound,
                                    placements.SERVER)),
        ('inner_agg_process', inner_process.initialize()),
    ])
    return intrinsics.federated_zip(initial_state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(value_type))
  def next_fn(state, value):
    scale_at_server = state['scale_factor']
    scale_at_clients = intrinsics.federated_broadcast(scale_at_server)
    bound_at_server = state['prior_norm_bound']
    bound_at_clients = intrinsics.federated_broadcast(bound_at_server)

    discretized = intrinsics.federated_map(
        discretize_fn, (value, scale_at_clients, bound_at_clients))

    inner_output = inner_process.next(state['inner_agg_process'], discretized)

    undiscretized_sum = intrinsics.federated_map(
        undiscretize_fn, (inner_output.result, scale_at_server))

    next_state = collections.OrderedDict([
        ('scale_factor', scale_at_server),
        ('prior_norm_bound', bound_at_server),
        ('inner_agg_process', inner_output.state),
    ])
    round_measurements = collections.OrderedDict([
        ('discretize', inner_output.measurements),
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(next_state),
        result=undiscretized_sum,
        measurements=intrinsics.federated_zip(round_measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Builds a process that discretizes, aggregates, then undiscretizes.

  Args:
    value_type: A tensor type, or a struct-with-python type that is a
      structure of tensors, whose component dtypes are all floats.

  Returns:
    An `aggregation_process.AggregationProcess`. Its state is an ordered
    struct with keys `step_size` and `inner_agg_process`. Its measurements
    carry the inner process's measurements under
    `deterministic_discretization` and, when
    `self._distortion_aggregation_factory` is not `None`, an aggregated
    per-client mean squared reconstruction error under `distortion`.

  Raises:
    TypeError: If `value_type` has an unsupported kind or is not a structure
      of floats.
  """
  # Validate `value_type` and remember its TF dtype(s) for undiscretization.
  if value_type.is_tensor():
    original_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    original_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda tensor_type: tensor_type.dtype, value_type)
  else:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {value_type!r}')

  # Only floating point components are accepted.
  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError(
        'Component dtypes of `value_type` must all be floats. '
        f'Found {value_type!r}.')

  measure_distortion = self._distortion_aggregation_factory is not None
  if measure_distortion:
    distortion_aggregation_process = (
        self._distortion_aggregation_factory.create(
            computation_types.to_type(tf.float32)))

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def discretize_fn(value, step_size):
    return _discretize_struct(value, step_size)

  discretized_type = discretize_fn.type_signature.result

  @tensorflow_computation.tf_computation(discretized_type, tf.float32)
  def undiscretize_fn(value, step_size):
    return _undiscretize_struct(value, step_size, original_dtype)

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def distortion_measurement_fn(value, step_size):
    # Round-trip the value and take the mean squared error over all elements.
    round_tripped = undiscretize_fn(
        discretize_fn(value, step_size), step_size)
    diff = tf.nest.map_structure(tf.subtract, round_tripped, value)
    diff_squared = tf.nest.map_structure(tf.square, diff)
    flattened = [
        tf.cast(tf.reshape(component, [-1]), tf.float32)
        for component in tf.nest.flatten(diff_squared)
    ]
    return tf.reduce_mean(tf.concat(flattened, axis=0))

  inner_process = self._inner_agg_factory.create(discretized_type)

  @federated_computation.federated_computation()
  def init_fn():
    # Pairs are listed in the order the state struct exposes its fields.
    initial_state = collections.OrderedDict([
        ('step_size',
         intrinsics.federated_value(self._step_size, placements.SERVER)),
        ('inner_agg_process', inner_process.initialize()),
    ])
    return intrinsics.federated_zip(initial_state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(value_type))
  def next_fn(state, value):
    step_at_server = state['step_size']
    step_at_clients = intrinsics.federated_broadcast(step_at_server)

    discretized = intrinsics.federated_map(discretize_fn,
                                           (value, step_at_clients))

    inner_output = inner_process.next(state['inner_agg_process'], discretized)

    undiscretized_sum = intrinsics.federated_map(
        undiscretize_fn, (inner_output.result, step_at_server))

    next_state = collections.OrderedDict([
        ('step_size', step_at_server),
        ('inner_agg_process', inner_output.state),
    ])
    round_measurements = collections.OrderedDict([
        ('deterministic_discretization', inner_output.measurements),
    ])

    if measure_distortion:
      client_distortions = intrinsics.federated_map(
          distortion_measurement_fn, (value, step_at_clients))
      # The distortion aggregator is run from a fresh initial state each
      # round; its state is not carried in this process's state.
      round_measurements['distortion'] = distortion_aggregation_process.next(
          distortion_aggregation_process.initialize(),
          client_distortions).result

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(next_state),
        result=undiscretized_sum,
        measurements=intrinsics.federated_zip(round_measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def test_returns_false(self, type_spec):
  """Asserts `is_structure_of_floats` is false for the given `type_spec`."""
  outcome = type_analysis.is_structure_of_floats(type_spec)
  self.assertFalse(outcome)