def federated_aggregate_keras_metric(metric_type, metric_config,
                                     federated_variables):
  """Aggregates the variables of a keras metric placed at CLIENTS to SERVER.

  The client-placed variables are summed element-wise with
  `tff.federated_aggregate`, the sums are assigned into a freshly constructed
  metric object, and that metric's `result()` is returned.

  Args:
    metric_type: A type object (type must inherit from
      `tf.keras.metrics.Metric`).
    metric_config: The result of calling `get_config()` on a metric object,
      used with `metric_type.from_config()` to locally construct a new metric
      object.
    federated_variables: A federated value placed on clients that is the value
      returned by `tf.keras.metrics.Metric.variables`.

  Returns:
    The result of calling `result()` on a `tf.keras.metrics.Metric` of type
    `metric_type`, after aggregating the `variables` placed at all CLIENTS.

  Raises:
    TypeError: If `metric_type.from_config(metric_config)` raises a
      `TypeError`; the original exception is attached as the cause.
  """
  member_type = federated_variables.type_signature.member

  @tff.tf_computation
  def zeros_fn():
    # Build a zero-valued structure mirroring the shape/dtype of every
    # variable; this is the initial accumulator of the aggregation.
    return anonymous_tuple.map_structure(
        lambda v: tf.zeros(v.shape, dtype=v.dtype), member_type)

  zeros = zeros_fn()

  # TODO(b/123995628): as of 2019-02-01 all variables created in a
  # `tf.keras.metrics.Metric` use the argument
  # `aggregation=tf.VariableAggregation.SUM`, hence below only uses `tf.add`.
  # This may change in the future (and the `tf.Variable.aggregation` property
  # will be exposed in a future TF version). Need to handle non-SUM variables.
  @tff.tf_computation(member_type, member_type)
  def accumulate(accumulators, variables):
    return anonymous_tuple.map_structure(tf.add, accumulators, variables)

  @tff.tf_computation(member_type, member_type)
  def merge(a, b):
    return anonymous_tuple.map_structure(tf.add, a, b)

  @tff.tf_computation(member_type)
  def report(accumulators):
    """Insert `accumulators` back into the keras metric to obtain result."""
    # NOTE: the following call requires that `metric_type` have a no argument
    # __init__ method, which will restrict the types of metrics that can be
    # used. This is somewhat limiting, but the pattern to use default arguments
    # and export the values in `get_config()` (see
    # `tf.keras.metrics.TopKCategoricalAccuracy`) works well.
    try:
      keras_metric = metric_type.from_config(metric_config)
    except TypeError as e:
      # Re-raise the error with a more helpful message, explicitly chained to
      # the original exception so its stack trace is preserved.
      raise TypeError(
          'Caught exception trying to call `{t}.from_config()` with '
          'config {c}. Confirm that {t}.__init__() has an argument for '
          'each member of the config.\nException: {e}'.format(
              t=metric_type, c=metric_config, e=e)) from e

    # Overwrite each metric variable with its aggregated value; `result()` is
    # only evaluated after all assignments have run.
    assignments = [
        v.assign(a) for v, a in zip(keras_metric.variables, accumulators)
    ]
    with tf.control_dependencies(assignments):
      return keras_metric.result()

  return tff.federated_aggregate(federated_variables, zeros, accumulate, merge,
                                 report)
def federated_aggregate_keras_metric(
    metrics: Union[tf.keras.metrics.Metric,
                   Sequence[tf.keras.metrics.Metric]], federated_values):
  """Aggregates the variables of keras metrics placed at CLIENTS to SERVER.

  Args:
    metrics: A single `tf.keras.metrics.Metric` or a `Sequence` of metrics. The
      order must match the order of variables in `federated_values`.
    federated_values: A single federated value, or a `Sequence` of federated
      values. The values must all have `tff.CLIENTS` placement. If value is a
      `Sequence` type, it must match the order of the sequence in `metrics`.

  Returns:
    The result of performing a federated sum on `federated_values`, then
    assigning the aggregated values into the variables of the corresponding
    `tf.keras.metrics.Metric` and calling `tf.keras.metrics.Metric.result`. The
    resulting structure has `tff.SERVER` placement. When `metrics` is not a
    single metric, the result is a `collections.OrderedDict` keyed by the names
    of the aggregated values.

  Raises:
    TypeError: If `type(metric).from_config(metric.get_config())` raises a
      `TypeError` for any metric; the original exception is attached as the
      cause.
  """
  member_types = tf.nest.map_structure(lambda t: t.type_signature.member,
                                       federated_values)

  @tff.tf_computation
  def zeros_fn():
    # `member_types` is a (potentially nested) `tff.StructType`, which is a
    # `structure.Struct`.
    return structure.map_structure(lambda v: tf.zeros(v.shape, dtype=v.dtype),
                                   member_types)

  zeros = zeros_fn()

  @tff.tf_computation(member_types, member_types)
  def accumulate(accumulators, variables):
    return tf.nest.map_structure(tf.add, accumulators, variables)

  @tff.tf_computation(member_types, member_types)
  def merge(a, b):
    return tf.nest.map_structure(tf.add, a, b)

  @tff.tf_computation(member_types)
  def report(accumulators):
    """Insert `accumulators` back into the keras metric to obtain result."""

    def finalize_metric(metric: tf.keras.metrics.Metric, values):
      # Note: the following call requires that `type(metric)` have a no
      # argument __init__ method, which will restrict the types of metrics
      # that can be used. This is somewhat limiting, but the pattern to use
      # default arguments and export the values in `get_config()` (see
      # `tf.keras.metrics.TopKCategoricalAccuracy`) works well.
      metric_cls = type(metric)
      # Fetched once so the same config is used for reconstruction and for the
      # error message below.
      metric_config = metric.get_config()
      try:
        # This is some trickery to reconstruct a metric object in the current
        # scope, so that the `tf.Variable`s get created when we desire.
        keras_metric = metric_cls.from_config(metric_config)
      except TypeError as e:
        # Re-raise the error with a more helpful message, explicitly chained
        # to the original exception so its stack trace is preserved.
        raise TypeError(
            'Caught exception trying to call `{t}.from_config()` with '
            'config {c}. Confirm that {t}.__init__() has an argument for '
            'each member of the config.\nException: {e}'.format(
                t=metric_cls, c=metric_config, e=e)) from e

      # Overwrite each metric variable with its aggregated value; `result()`
      # is only evaluated after all assignments have run.
      assignments = [
          v.assign(a) for v, a in zip(keras_metric.variables, values)
      ]
      with tf.control_dependencies(assignments):
        return keras_metric.result()

    if isinstance(metrics, tf.keras.metrics.Metric):
      # Only a single metric to aggregate.
      return finalize_metric(metrics, accumulators)

    # Otherwise map over all the metrics, pairing each metric with the
    # (name, values) entry at the same position in `accumulators`.
    return collections.OrderedDict([
        (name, finalize_metric(metric, values))
        for metric, (name, values) in zip(metrics, accumulators.items())
    ])

  return tff.federated_aggregate(federated_values, zeros, accumulate, merge,
                                 report)