def create_scalar_multiply_operator(
    self, operand_type: computation_types.Type,
    scalar_type: computation_types.TensorType
) -> local_computation_factory_base.ComputationProtoAndType:
  """Builds an XLA computation that multiplies every operand leaf by a scalar.

  The resulting computation has type `(<T, S> -> T)`, where `T` is
  `operand_type` and `S` is `scalar_type`. Every leaf tensor of the operand is
  combined with the scalar parameter via `xla_client.ops.Mul`.

  Args:
    operand_type: The type of the operand; must be a tensor or a structure of
      tensors.
    scalar_type: The `computation_types.TensorType` of the multiplier.

  Returns:
    A `(computation_proto, computation_type)` pair.

  Raises:
    ValueError: if `operand_type` is not a tensor or a structure of tensors.
  """
  py_typecheck.check_type(operand_type, computation_types.Type)
  py_typecheck.check_type(scalar_type, computation_types.TensorType)
  if not type_analysis.is_structure_of_tensors(operand_type):
    raise ValueError(
        'Not a tensor or a structure of tensors: {}'.format(
            str(operand_type)))

  operand_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
      operand_type)
  scalar_shape = _xla_tensor_shape_from_tff_tensor_type(scalar_type)
  tensor_count = len(operand_shapes)

  builder = xla_client.XlaBuilder('comp')
  # All inputs arrive as one flat tuple: the operand leaves first, then the
  # scalar as the final element.
  param_shapes = operand_shapes + [scalar_shape]
  param = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.tuple_shape(param_shapes))
  multiplier = xla_client.ops.GetTupleElement(param, tensor_count)
  products = [
      xla_client.ops.Mul(xla_client.ops.GetTupleElement(param, position),
                         multiplier)
      for position in range(tensor_count)
  ]
  xla_client.ops.Tuple(builder, products)
  xla_computation = builder.build()

  arg_type = computation_types.StructType(
      [(None, operand_type), (None, scalar_type)])
  comp_type = computation_types.FunctionType(arg_type, operand_type)
  # The parameter binding covers every operand tensor plus the trailing scalar.
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, list(range(tensor_count + 1)), comp_type)
  return comp_pb, comp_type
def _apply_generic_op(op, arg):
  """Applies the binary operator `op` to `arg`, upcasting as needed.

  `arg` is handed to
  `building_block_factory.apply_binary_operator_with_upcast`. If its type is
  neither federated nor a structure of tensors, it is first passed through
  `building_block_factory.create_federated_zip`.
  """
  arg_type = arg.type_signature
  usable_as_is = (
      arg_type.is_federated() or
      type_analysis.is_structure_of_tensors(arg_type))
  if not usable_as_is:
    # Federated elements nested inside a struct must be zipped together before
    # the binary operator constructor can consume them.
    arg = building_block_factory.create_federated_zip(arg)
  return building_block_factory.apply_binary_operator_with_upcast(arg, op)
def _check_value_type(value_type):
  """Check value_type meets documented criteria.

  Raises:
    TypeError: if `value_type` is neither a `TensorType` nor a
      `StructWithPythonType` whose leaves are all tensors, or if its component
      dtypes are not uniformly integers or uniformly floats.
  """
  supported_container = value_type.is_tensor() or (
      value_type.is_struct_with_python() and
      type_analysis.is_structure_of_tensors(value_type))
  if not supported_container:
    raise TypeError('Expected `value_type` to be `TensorType` or '
                    '`StructWithPythonType` containing only `TensorType`. '
                    f'Found type: {repr(value_type)}')

  all_floats = type_analysis.is_structure_of_floats(value_type)
  if not all_floats and not type_analysis.is_structure_of_integers(value_type):
    raise TypeError('Component dtypes of `value_type` must be all integers or '
                    f'all floats. Found {value_type}.')
def _validate_value_type_and_encoders(value_type, encoders, encoder_type):
  """Validates if `value_type` and `encoders` are compatible.

  Args:
    value_type: The TFF type of the values to be encoded.
    encoders: Either a single encoder (an instance of one of
      `_ALLOWED_ENCODERS`) or a nested structure of such encoders.
    encoder_type: Forwarded unchanged to `_validate_encoder` for every encoder.

  Raises:
    ValueError: if `encoders` is a single encoder but `value_type` is not a
      `computation_types.TensorType`.
    TypeError: if `encoders` is a structure but `value_type` is not a tensor or
      a structure of tensors.
  """
  if not isinstance(encoders, _ALLOWED_ENCODERS):
    # `encoders` is a container. Only "structure of tensors" is checked here;
    # each encoder is then validated against the `TensorSpec` leaf at the same
    # position of `value_type`.
    if not type_analysis.is_structure_of_tensors(value_type):
      raise TypeError('`value_type` is not compatible with the expected input '
                      'of the `encoders`.')
    value_tensorspecs = type_conversions.type_to_tf_tensor_specs(value_type)

    def _validate_leaf(encoder, tensorspec):
      return _validate_encoder(encoder, tensorspec, encoder_type)

    tf.nest.map_structure(_validate_leaf, encoders, value_tensorspecs)
    return

  # A single (non-container) encoder can only be paired with a single tensor.
  if not isinstance(value_type, computation_types.TensorType):
    raise ValueError(
        '`value_type` and `encoders` do not have the same structure.')
  _validate_encoder(encoders, value_type, encoder_type)
def _create_xla_binary_op_computation(type_spec, xla_binary_op_constructor):
  """Helper for constructing computations that implement binary operators.

  The constructed computation is of type `(<T,T> -> T)`, where `T` is the type
  of the operand (`type_spec`). Both operands are delivered as one flat XLA
  parameter tuple: the leaves of the first operand, followed by the leaves of
  the second. The operator is applied to each pair of leaves at matching
  positions.

  Args:
    type_spec: The type of a single operand.
    xla_binary_op_constructor: A two-argument callable that constructs a binary
      xla op from tensor parameters (such as `xla_client.ops.Add` or similar).

  Returns:
    An instance of `local_computation_factory_base.ComputationProtoAndType`.

  Raises:
    ValueError: if the arguments are invalid.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not type_analysis.is_structure_of_tensors(type_spec):
    raise ValueError('Not a tensor or a structure of tensors: {}'.format(
        str(type_spec)))

  leaf_shapes = _xla_tensor_shape_list_from_from_tff_tensor_or_struct_type(
      type_spec)
  leaf_count = len(leaf_shapes)

  builder = xla_client.XlaBuilder('comp')
  param = xla_client.ops.Parameter(
      builder, 0, xla_client.Shape.tuple_shape(leaf_shapes * 2))

  def _element(position):
    return xla_client.ops.GetTupleElement(param, position)

  # Leaf i of the left operand pairs with leaf i + leaf_count (right operand).
  combined = [
      xla_binary_op_constructor(_element(i), _element(i + leaf_count))
      for i in range(leaf_count)
  ]
  xla_client.ops.Tuple(builder, combined)
  xla_computation = builder.build()

  comp_type = computation_types.FunctionType(
      computation_types.StructType([(None, type_spec)] * 2), type_spec)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, list(range(2 * leaf_count)), comp_type)
  return comp_pb, comp_type
def create_constant_from_scalar(
    self, value, type_spec: computation_types.Type
) -> local_computation_factory_base.ComputationProtoAndType:
  """Builds a no-argument XLA computation returning `value` in every element.

  Each leaf tensor of `type_spec` becomes an XLA constant built with
  `np.full`, using that leaf's shape dimensions and numpy dtype and filled with
  `value`.

  Args:
    value: The fill value used for every element of every leaf tensor.
    type_spec: The result type; must be a tensor or a structure of tensors.

  Returns:
    A `(computation_proto, computation_type)` pair whose function type has no
    parameter and result `type_spec`.

  Raises:
    ValueError: if `type_spec` is not a tensor or a structure of tensors.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not type_analysis.is_structure_of_tensors(type_spec):
    raise ValueError(
        'Not a tensor or a structure of tensors: {}'.format(
            str(type_spec)))

  builder = xla_client.XlaBuilder('comp')
  # Arguments are always supplied as a tuple for consistency (see the comments
  # in `computation.proto`). There are no real arguments here, so the
  # parameter is an empty tuple.
  xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))

  if isinstance(type_spec, computation_types.TensorType):
    leaf_types = [type_spec]
  else:
    leaf_types = structure.flatten(type_spec)

  constants = []
  for leaf_type in leaf_types:
    py_typecheck.check_type(leaf_type, computation_types.TensorType)
    filled = np.full(
        shape=leaf_type.shape.dims,
        fill_value=value,
        dtype=leaf_type.dtype.as_numpy_dtype)
    constants.append(xla_client.ops.Constant(builder, filled))

  # Results are likewise returned as a single flat tuple; the nested TFF
  # structure is recovered from the binding.
  xla_client.ops.Tuple(builder, constants)
  xla_computation = builder.build()

  comp_type = computation_types.FunctionType(None, type_spec)
  comp_pb = xla_serialization.create_xla_tff_computation(
      xla_computation, [], comp_type)
  return comp_pb, comp_type
def create(self, value_type):
  """Creates an `AggregationProcess` for `value_type`.

  Side effect: stores the total number of elements across all leaf tensors of
  `value_type` in `self._client_dim`. The process wraps the one built by
  `self._build_aggregation_factory()`. Measurements are post-processed by
  `self._derive_measurements`. When `self._auto_l2_clip` is set, the state is
  passed through `self._autotune_component_states` each round.

  Raises:
    TypeError: if `value_type` is neither a `TensorType` nor a
      `StructWithPythonType` of tensors, or if its component dtypes are not all
      integers or all floats.
  """
  if (value_type.is_struct_with_python() and
      type_analysis.is_structure_of_tensors(value_type)):
    per_leaf_sizes = type_conversions.structure_from_tensor_type_tree(
        lambda leaf: leaf.shape.num_elements(), value_type)
    self._client_dim = sum(tf.nest.flatten(per_leaf_sizes))
  elif value_type.is_tensor():
    self._client_dim = value_type.shape.num_elements()
  else:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {repr(value_type)}')

  is_numeric = (
      type_analysis.is_structure_of_floats(value_type) or
      type_analysis.is_structure_of_integers(value_type))
  if not is_numeric:
    raise TypeError(
        'Component dtypes of `value_type` must all be integers '
        f'or floats. Found {repr(value_type)}.')

  ddp_process = self._build_aggregation_factory().create(value_type)
  init_fn = ddp_process.initialize

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    output = ddp_process.next(state, value)
    # Measurements are derived from the state *before* any autotuning.
    derived_measurements = self._derive_measurements(output.state,
                                                     output.measurements)
    if self._auto_l2_clip:
      next_state = self._autotune_component_states(output.state)
    else:
      next_state = output.state
    return measured_process.MeasuredProcessOutput(
        state=next_state,
        result=output.result,
        measurements=derived_measurements)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Creates an `AggregationProcess` wrapping `self._inner_agg_factory`.

  `value_type` must be a tensor or a struct of tensors, all of integer dtype.
  When `self._estimate_stddev` is set, the total number of elements must
  exceed 1, and a warning is issued when it is 100 or fewer.

  Raises:
    TypeError: if `value_type` has an unsupported structure or non-integer
      component dtypes.
    ValueError: if stddev estimation is requested for a `value_type` with at
      most one element.
  """
  is_tensor_struct = (
      value_type.is_struct() and
      type_analysis.is_structure_of_tensors(value_type))
  if is_tensor_struct:
    per_leaf_sizes = type_conversions.structure_from_tensor_type_tree(
        lambda leaf: leaf.shape.num_elements(), value_type)
    client_dim = sum(tf.nest.flatten(per_leaf_sizes))
  elif value_type.is_tensor():
    client_dim = value_type.shape.num_elements()
  else:
    raise TypeError('Expected `value_type` to be `TensorType` or '
                    '`StructType` containing only `TensorType`. '
                    f'Found type: {repr(value_type)}')

  if not type_analysis.is_structure_of_integers(value_type):
    raise TypeError(
        'Component dtypes of `value_type` must all be integers. '
        f'Found {repr(value_type)}.')

  # Estimating a standard deviation needs more than one element; very small
  # inputs are allowed but flagged as potentially noisy.
  if self._estimate_stddev:
    if client_dim <= 1:
      raise ValueError(
          'The stddev estimation procedure expects more than '
          '1 element from `value_type`. Found `value_type` of '
          f'{value_type} with {client_dim} elements.')
    if client_dim <= 100:
      warnings.warn(
          f'`value_type` has only {client_dim} elements. The '
          'estimated standard deviation may be noisy. Consider '
          'setting `estimate_stddev` to True only if the input '
          'tensor/structure have more than 100 elements.')

  inner_process = self._inner_agg_factory.create(value_type)
  init_fn = inner_process.initialize
  next_fn = self._create_next_fn(inner_process.next,
                                 init_fn.type_signature.result, value_type)
  return aggregation_process.AggregationProcess(init_fn, next_fn)
def from_keras_model(
    keras_model: tf.keras.Model,
    loss: Loss,
    input_spec,
    loss_weights: Optional[List[float]] = None,
    metrics: Optional[List[tf.keras.metrics.Metric]] = None
) -> model_lib.Model:
  """Builds a `tff.learning.Model` from a `tf.keras.Model`.

  The `tff.learning.Model` returned by this function uses `keras_model` for its
  forward pass and autodifferentiation steps.

  Notice that since TFF couples the `tf.keras.Model` and `loss`, TFF needs a
  slightly different notion of "fully specified type" than pure Keras does.
  That is, the model `M` takes inputs of type `x` and produces predictions of
  type `p`; the loss function `L` takes inputs of type `<p, y>` (where `y` is
  the ground truth label type) and produces a scalar. Therefore in order to
  fully specify the type signatures for computations in which the generated
  `tff.learning.Model` will appear, TFF needs the type `y` in addition to the
  type `x`.

  Note: This function does not currently accept subclassed `tf.keras.Models`,
  as it makes assumptions about presence of certain attributes which are
  guaranteed to exist through the functional or Sequential API but are not
  necessarily present for subclassed models.

  Args:
    keras_model: A `tf.keras.Model` object that is not compiled.
    loss: A single `tf.keras.losses.Loss` or a list of losses-per-output. If a
      single loss is provided, then all model output (as well as all prediction
      information) is passed to the loss; this includes situations of multiple
      model outputs and/or predictions. If multiple losses are provided as a
      list, then each loss is expected to correspond to a model output; the
      model will attempt to minimize the sum of all individual losses
      (optionally weighted using the `loss_weights` argument).
    input_spec: A structure of `tf.TensorSpec`s or `tff.Type` specifying the
      type of arguments the model expects. If `input_spec` is a `tff.Type`, its
      leaf nodes must be `TensorType`s. Note that `input_spec` must be a
      compound structure of two elements, specifying both the data fed into the
      model (x) to generate predictions as well as the expected type of the
      ground truth (y). If provided as a list, it must be in the order [x, y].
      If provided as a dictionary, the keys must explicitly be named `'x'` and
      `'y'`.
    loss_weights: (Optional) A list of Python floats used to weight the loss
      contribution of each model output (when providing a list of losses for
      the `loss` argument).
    metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.

  Returns:
    A `tff.learning.Model` object.

  Raises:
    TypeError: If `keras_model` is not an instance of `tf.keras.Model`, if
      `loss` is not an instance of `tf.keras.losses.Loss` nor a list of
      instances of `tf.keras.losses.Loss`, if `input_spec` is a `tff.Type` but
      the leaf nodes are not `tff.TensorType`s, if `loss_weights` is provided
      but is not a list of floats, or if `metrics` is provided but is not a
      list of instances of `tf.keras.metrics.Metric`.
    ValueError: If `keras_model` was compiled, if `loss` is a list of unequal
      length to the number of outputs of `keras_model`, if `loss_weights` is
      specified but `loss` is not a list, if `input_spec` does not contain
      exactly two elements, or if `input_spec` is a dictionary and does not
      contain keys `'x'` and `'y'`.
  """
  # --- The Keras model itself. ---
  py_typecheck.check_type(keras_model, tf.keras.Model)
  if keras_model._is_compiled:  # pylint: disable=protected-access
    raise ValueError('`keras_model` must not be compiled')

  # --- Normalize `loss` / `loss_weights` into two parallel lists. ---
  if isinstance(loss, list):
    output_count = len(keras_model.outputs)
    if len(loss) != output_count:
      raise ValueError(
          'If a loss list is provided, `keras_model` must have '
          'equal number of outputs to the losses.\nloss: {}\nof '
          'length: {}.\noutputs: {}\nof length: {}.'.format(
              loss, len(loss), keras_model.outputs, output_count))
    for loss_fn in loss:
      py_typecheck.check_type(loss_fn, tf.keras.losses.Loss)

    if loss_weights is not None:
      if len(loss) != len(loss_weights):
        raise ValueError(
            '`keras_model` must have equal number of losses and loss_weights.'
            '\nloss: {}\nof length: {}.'
            '\nloss_weights: {}\nof length: {}.'.format(
                loss, len(loss), loss_weights, len(loss_weights)))
      for weight in loss_weights:
        py_typecheck.check_type(weight, float)
    else:
      # Unweighted: every output contributes equally.
      loss_weights = [1.0] * len(loss)
  else:
    py_typecheck.check_type(loss, tf.keras.losses.Loss)
    if loss_weights is not None:
      raise ValueError(
          '`loss_weights` cannot be used if `loss` is not a list.')
    loss = [loss]
    loss_weights = [1.0]

  # --- Validate `input_spec` and convert `tff.Type`s to `tf.TensorSpec`s. ---
  if len(input_spec) != 2:
    raise ValueError(
        'The top-level structure in `input_spec` must contain '
        'exactly two top-level elements, as it must specify type '
        'information for both inputs to and predictions from the '
        'model. You passed input spec {}.'.format(input_spec))

  if not isinstance(input_spec, computation_types.Type):
    tf.nest.map_structure(
        lambda spec: py_typecheck.check_type(spec, tf.TensorSpec,
                                             'input spec member'),
        input_spec)
  else:
    if not type_analysis.is_structure_of_tensors(input_spec):
      raise TypeError(
          'Expected a `tff.Type` with all the leaf nodes being '
          '`tff.TensorType`s, found an input spec {}.'.format(input_spec))
    input_spec = type_conversions.structure_from_tensor_type_tree(
        lambda tensor_type: tf.TensorSpec(tensor_type.shape,
                                          tensor_type.dtype),
        input_spec)

  # Mapping-style specs must name their two entries explicitly.
  if isinstance(input_spec, collections.abc.Mapping):
    if 'x' not in input_spec:
      raise ValueError(
          'The `input_spec` is a collections.abc.Mapping (e.g., a dict), so it '
          'must contain an entry with key `\'x\'`, representing the input(s) '
          'to the Keras model.')
    if 'y' not in input_spec:
      raise ValueError(
          'The `input_spec` is a collections.abc.Mapping (e.g., a dict), so it '
          'must contain an entry with key `\'y\'`, representing the label(s) '
          'to be used in the Keras loss(es).')

  # --- Metrics. ---
  if metrics is not None:
    py_typecheck.check_type(metrics, list)
    for metric in metrics:
      py_typecheck.check_type(metric, tf.keras.metrics.Metric)
  else:
    metrics = []

  return _KerasModel(
      keras_model,
      input_spec=input_spec,
      loss_fns=loss,
      loss_weights=loss_weights,
      metrics=metrics)
def create(self, value_type):
  """Creates a discretize -> inner aggregate -> undiscretize process.

  Client values are mapped through the computation returned by
  `_build_discretize_fn` (using the broadcast `scale_factor` and
  `prior_norm_bound` from the server state) and aggregated with
  `self._inner_agg_factory`. The aggregate is then mapped back through
  `_undiscretize_struct` using the server-side scale factor.

  Raises:
    TypeError: if `value_type` is neither a `TensorType` nor a
      `StructWithPythonType` of tensors; if it is such a struct while
      `self._prior_norm_bound` is truthy; or if any component dtype is not a
      float.
  """
  # Parse out the TF dtype(s); struct inputs keep their structure.
  if value_type.is_tensor():
    tf_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    if self._prior_norm_bound:
      raise TypeError(
          'If `prior_norm_bound` is specified, `value_type` must '
          f'be `TensorType`. Found type: {repr(value_type)}.')
    tf_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda leaf: leaf.dtype, value_type)
  else:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {repr(value_type)}')

  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError(
        'Component dtypes of `value_type` must all be floats. '
        f'Found {repr(value_type)}.')

  discretize_fn = _build_discretize_fn(value_type, self._stochastic,
                                       self._beta)

  @tensorflow_computation.tf_computation(
      discretize_fn.type_signature.result, tf.float32)
  def undiscretize_fn(value, scale_factor):
    # Casts back to the original dtype(s) captured above.
    return _undiscretize_struct(value, scale_factor, tf_dtype)

  discretized_type = discretize_fn.type_signature.result
  inner_agg_process = self._inner_agg_factory.create(discretized_type)

  @federated_computation.federated_computation()
  def init_fn():
    initial_state = collections.OrderedDict(
        scale_factor=intrinsics.federated_value(
            self._scale_factor, placements.SERVER),
        prior_norm_bound=intrinsics.federated_value(
            self._prior_norm_bound, placements.SERVER),
        inner_agg_process=inner_agg_process.initialize())
    return intrinsics.federated_zip(initial_state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    server_scale = state['scale_factor']
    client_scale = intrinsics.federated_broadcast(server_scale)
    server_norm_bound = state['prior_norm_bound']
    client_norm_bound = intrinsics.federated_broadcast(server_norm_bound)

    discretized = intrinsics.federated_map(
        discretize_fn, (value, client_scale, client_norm_bound))
    inner_output = inner_agg_process.next(state['inner_agg_process'],
                                          discretized)
    aggregate = intrinsics.federated_map(
        undiscretize_fn, (inner_output.result, server_scale))

    # scale_factor and prior_norm_bound are carried over unchanged; only the
    # inner process state advances.
    next_state = collections.OrderedDict(
        scale_factor=server_scale,
        prior_norm_bound=server_norm_bound,
        inner_agg_process=inner_output.state)
    round_measurements = collections.OrderedDict(
        discretize=inner_output.measurements)
    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(next_state),
        result=aggregate,
        measurements=intrinsics.federated_zip(round_measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Creates a deterministic discretize -> aggregate -> undiscretize process.

  Client values are mapped through `_discretize_struct` using the broadcast
  `step_size` from server state and aggregated by `self._inner_agg_factory`.
  The aggregate is then mapped back with `_undiscretize_struct`. When
  `self._distortion_aggregation_factory` is not None, each round also reports
  the aggregated per-client mean squared round-trip error under
  `measurements['distortion']`. That process is freshly initialized every
  round, so its state is not carried across rounds.

  Raises:
    TypeError: if `value_type` is neither a `TensorType` nor a
      `StructWithPythonType` of tensors, or if any component dtype is not a
      float.
  """
  # Parse out the TF dtype(s); struct inputs keep their structure.
  if value_type.is_tensor():
    tf_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    tf_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda leaf: leaf.dtype, value_type)
  else:
    raise TypeError('Expected `value_type` to be `TensorType` or '
                    '`StructWithPythonType` containing only `TensorType`. '
                    f'Found type: {repr(value_type)}')

  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError('Component dtypes of `value_type` must all be floats. '
                    f'Found {repr(value_type)}.')

  if self._distortion_aggregation_factory is not None:
    distortion_aggregation_process = (
        self._distortion_aggregation_factory.create(
            computation_types.to_type(tf.float32)))

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def discretize_fn(value, step_size):
    return _discretize_struct(value, step_size)

  @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                         tf.float32)
  def undiscretize_fn(value, step_size):
    return _undiscretize_struct(value, step_size, tf_dtype)

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def distortion_measurement_fn(value, step_size):
    # Mean squared error of a discretize/undiscretize round trip, taken over
    # all elements of all leaves (cast to float32 before concatenation).
    round_trip = undiscretize_fn(discretize_fn(value, step_size), step_size)
    squared_err = tf.nest.map_structure(
        tf.square, tf.nest.map_structure(tf.subtract, round_trip, value))
    flat_squared_errs = [
        tf.cast(tf.reshape(leaf, [-1]), tf.float32)
        for leaf in tf.nest.flatten(squared_err)
    ]
    return tf.reduce_mean(tf.concat(flat_squared_errs, axis=0))

  inner_agg_process = self._inner_agg_factory.create(
      discretize_fn.type_signature.result)

  @federated_computation.federated_computation()
  def init_fn():
    initial_state = collections.OrderedDict(
        step_size=intrinsics.federated_value(self._step_size,
                                             placements.SERVER),
        inner_agg_process=inner_agg_process.initialize())
    return intrinsics.federated_zip(initial_state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    server_step = state['step_size']
    client_step = intrinsics.federated_broadcast(server_step)

    discretized = intrinsics.federated_map(discretize_fn, (value, client_step))
    inner_output = inner_agg_process.next(state['inner_agg_process'],
                                          discretized)
    aggregate = intrinsics.federated_map(undiscretize_fn,
                                         (inner_output.result, server_step))

    # step_size is carried over unchanged; only the inner state advances.
    next_state = collections.OrderedDict(
        step_size=server_step, inner_agg_process=inner_output.state)
    round_measurements = collections.OrderedDict(
        deterministic_discretization=inner_output.measurements)

    if self._distortion_aggregation_factory is not None:
      client_distortions = intrinsics.federated_map(distortion_measurement_fn,
                                                    (value, client_step))
      round_measurements['distortion'] = distortion_aggregation_process.next(
          distortion_aggregation_process.initialize(),
          client_distortions).result

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(next_state),
        result=aggregate,
        measurements=intrinsics.federated_zip(round_measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)