def __init__(self, keras_model: tf.keras.Model, input_spec,
             loss_fns: List[tf.keras.losses.Loss],
             loss_weights: List[float],
             metrics: List[tf.keras.metrics.Metric]):
  self._keras_model = keras_model
  self._input_spec = input_spec
  self._loss_fns = loss_fns
  self._loss_weights = loss_weights
  self._metrics = metrics

  # This is defined here so that it closes over the `loss_fns` and
  # `loss_weights`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super().__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      if len(self._loss_fns) == 1:
        batch_size = tf.shape(y_pred)[0]
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        batch_size = tf.shape(y_pred[0])[0]
        batch_loss = tf.zeros(())
        for i in range(len(self._loss_fns)):
          batch_loss += self._loss_weights[i] * self._loss_fns[i](
              y_true[i], y_pred[i])
      return super().update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)
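# Standalone sketch (not from the original source) of why `update_state`
# above passes `batch_size` as the second argument: `tf.keras.metrics.Mean`
# treats it as a sample weight, so the running result is a per-example mean
# across batches of different sizes, not a mean of per-batch means.
import tensorflow as tf

m = tf.keras.metrics.Mean(name='loss')
m.update_state(1.0, sample_weight=2.0)  # batch of 2 examples, mean loss 1.0
m.update_state(2.0, sample_weight=8.0)  # batch of 8 examples, mean loss 2.0
assert abs(float(m.result()) - 1.8) < 1e-6  # (2*1.0 + 8*2.0) / 10 = 1.8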
def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
  # TODO(b/124477598): the following set_session() should be removed in the
  # future. This is a workaround for Keras caching sessions in a way that
  # isn't compatible with TFF. This is already fixed in TF master, but not as
  # of v1.13.1.
  #
  # We do not use .clear_session() because it blows away the graph stack by
  # resetting the default graph.
  tf.keras.backend.set_session(None)

  if hasattr(dummy_batch, '_asdict'):
    dummy_batch = dummy_batch._asdict()
  # Convert input to tensors, possibly from nested lists that need to be
  # converted to a single top-level tensor.
  dummy_tensors = collections.OrderedDict([
      (k, tf.convert_to_tensor_or_sparse_tensor(v))
      for k, v in six.iteritems(dummy_batch)
  ])
  # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
  # variables until they are called on input. We force that here.
  inner_model(dummy_tensors['x'])

  def _tensor_spec_with_undefined_batch_dim(tensor):
    # Remove the batch dimension and leave it unspecified.
    spec = tf.TensorSpec(
        shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
    return spec

  self._input_spec = nest.map_structure(
      _tensor_spec_with_undefined_batch_dim, dummy_tensors)

  self._keras_model = inner_model
  self._loss_fn = loss_fn
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fn`.
  class _WeightedMeanLossMetric(keras_metrics.Metric):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super(_WeightedMeanLossMetric, self).__init__(name, dtype)
      self._total_loss = self.add_weight('total_loss', initializer='zeros')
      self._total_weight = self.add_weight('total_weight', initializer='zeros')
      self._loss_fn = loss_fn

    def update_state(self, y_true, y_pred, sample_weight=None):
      y_true = tf.cast(y_true, self._dtype)
      y_pred = tf.cast(y_pred, self._dtype)

      # _loss_fn is expected to return the scalar mean loss, so we multiply by
      # the batch_size to get back to the total loss.
      batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
      batch_total_loss = self._loss_fn(y_true, y_pred) * batch_size

      op = self._total_loss.assign_add(batch_total_loss)
      with tf.control_dependencies([op]):
        return self._total_weight.assign_add(batch_size)

    def result(self):
      return tf.div_no_nan(self._total_loss, self._total_weight)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = nest.map_structure(tf.TensorSpec.from_tensor,
                                                 self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(
      metric_variable_type_dict, tff.CLIENTS, all_equal=False)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before calling a tff.Computation. Here we store them in a TFF-specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.add_to_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE,
                         variable)
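# Illustrative sketch (plain Python stand-in, not TFF code) of the
# accumulator pattern above: `_total_loss` and `_total_weight` implement a
# weighted mean by hand because this TF1-era metric must express each update
# as graph ops ordered by `tf.control_dependencies`.
total_loss, total_weight = 0.0, 0.0
for mean_loss, batch_size in [(1.0, 2.0), (2.0, 8.0)]:
  total_loss += mean_loss * batch_size  # mirrors _total_loss.assign_add
  total_weight += batch_size            # mirrors _total_weight.assign_add
# Mirrors tf.div_no_nan in result(): 0.0 when no examples have been seen.
result = total_loss / total_weight if total_weight else 0.0
assert result == 1.8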
def __init__(self, inner_model, dummy_batch, loss_fns, loss_weights=None,
             metrics=None):
  # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
  # variables until they are called on input. We force that here.
  if isinstance(dummy_batch, collections.Mapping):
    inner_model(dummy_batch['x'])
  else:
    inner_model(dummy_batch[0])

  def _tensor_spec_with_undefined_batch_dim(tensor):
    # Remove the batch dimension and leave it unspecified.
    spec = tf.TensorSpec(
        shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
    return spec

  self._input_spec = tf.nest.map_structure(
      _tensor_spec_with_undefined_batch_dim, dummy_batch)

  self._keras_model = inner_model
  self._loss_fns = loss_fns

  if isinstance(loss_weights, collections.Mapping):
    self._loss_weights = []
    for name in inner_model.output_names:
      if name not in loss_weights:
        raise KeyError('Output missing from loss_weights dictionary.'
                       '\nloss_weights: {}\noutputs: {}'.format(
                           list(loss_weights.keys()),
                           inner_model.output_names))
      else:
        self._loss_weights.append(loss_weights[name])
  else:
    if loss_weights is None:
      self._loss_weights = [1.0 for _ in range(len(loss_fns))]
    else:
      self._loss_weights = loss_weights

  # Re-bind so the metric class below closes over the resolved list.
  loss_weights = self._loss_weights
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fns` and
  # `loss_weights`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super(_WeightedMeanLossMetric, self).__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      if len(self._loss_fns) == 1:
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        batch_loss = tf.zeros(())
        for i in range(len(self._loss_fns)):
          y_t = tf.cast(y_true[i], self._dtype)
          y_p = tf.cast(y_pred[i], self._dtype)
          batch_loss += self._loss_weights[i] * self._loss_fns[i](y_t, y_p)
        batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

      return super(_WeightedMeanLossMetric, self).update_state(
          batch_loss, batch_size)

  class _TrainingTimeHistory(tf.keras.metrics.Sum):

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Training time is recorded explicitly via `log_time`, not derived from
      # labels and predictions, so the standard update is a no-op.
      pass

    def log_time(self, time_value):
      return super(_TrainingTimeHistory, self).update_state(values=time_value)

  self._loss_metric = _WeightedMeanLossMetric()
  self._training_timing = _TrainingTimeHistory(name='training_time_sec')

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before calling a tff.Computation. Here we store them in a TFF-specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.compat.v1.add_to_collection(
        graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
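# Hypothetical sketch (the output names and weights are invented) of the
# dict-form `loss_weights` resolution above: a mapping keyed by the model's
# `output_names` is flattened into a list aligned with the outputs, and a
# missing key raises KeyError.
output_names = ['digit', 'parity']  # stand-in for inner_model.output_names
loss_weights = {'digit': 0.9, 'parity': 0.1}
resolved = [loss_weights[name] for name in output_names]
assert resolved == [0.9, 0.1]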
def federated_output_computation(self):
  # Sums each client's `num_over` counter into a single server-side total.
  return tff.federated_computation(
      lambda metrics: {'num_over': tff.federated_sum(metrics.num_over)})
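# Sketch of the aggregation pattern above (assumes the standard TFF API of
# this era): `tff.federated_sum` takes a value placed at `tff.CLIENTS` and
# returns its sum placed at `tff.SERVER`.
import tensorflow as tf
import tensorflow_federated as tff

@tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
def sum_at_server(client_values):
  return tff.federated_sum(client_values)

# One value per client: sum_at_server([1, 2, 3]) == 6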
def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
  # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
  # variables until they are called on input. We force that here.
  inner_model(dummy_batch['x'])

  def _tensor_spec_with_undefined_batch_dim(tensor):
    # Remove the batch dimension and leave it unspecified.
    spec = tf.TensorSpec(
        shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
    return spec

  self._input_spec = tf.nest.map_structure(
      _tensor_spec_with_undefined_batch_dim, dummy_batch)

  self._keras_model = inner_model
  self._loss_fn = loss_fn
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fn`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super(_WeightedMeanLossMetric, self).__init__(name, dtype)
      self._loss_fn = loss_fn

    def update_state(self, y_true, y_pred, sample_weight=None):
      y_true = tf.cast(y_true, self._dtype)
      y_pred = tf.cast(y_pred, self._dtype)
      batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
      batch_loss = self._loss_fn(y_true, y_pred)
      return super(_WeightedMeanLossMetric, self).update_state(
          batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before calling a tff.Computation. Here we store them in a TFF-specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.add_to_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE,
                         variable)
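# Standalone sketch (the shapes are invented) of the
# `_tensor_spec_with_undefined_batch_dim` helper above: the concrete batch
# dimension of the dummy tensor is replaced with `None` so the resulting
# spec accepts batches of any size.
import tensorflow as tf

t = tf.zeros([32, 28, 28])
spec = tf.TensorSpec(shape=[None] + t.shape.dims[1:], dtype=t.dtype)
# spec == tf.TensorSpec(shape=(None, 28, 28), dtype=tf.float32)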
def __init__(self, inner_model, input_spec, loss_fns, loss_weights=None,
             metrics=None):
  self._input_spec = input_spec

  if not loss_fns:
    raise ValueError(
        'Must specify at least one loss function in loss_fns, got: '
        '{l}'.format(l=loss_fns))
  if (bool(len(loss_fns) == 1) != tf.is_tensor(inner_model.output) or
      (isinstance(inner_model.output, list) and
       len(loss_fns) != len(inner_model.output))):
    raise ValueError('Must specify the same number of loss_fns as model '
                     'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
                         l=loss_fns, o=inner_model.output))
  self._loss_fns = loss_fns

  if loss_weights is None:
    loss_weights = [1.0] * len(loss_fns)
  else:
    py_typecheck.check_type(loss_weights, collections.Sequence)
    if len(loss_weights) != len(loss_fns):
      raise ValueError('Must specify the same number of '
                       'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
                       'loss_weights: {lw}\nloss_fns: {lf}'.format(
                           lw=loss_weights,
                           llw=len(loss_weights),
                           lf=loss_fns,
                           llf=len(loss_fns)))
  self._loss_weights = loss_weights

  self._keras_model = inner_model
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fns` and
  # `loss_weights`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super().__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      if len(self._loss_fns) == 1:
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        batch_loss = tf.zeros(())
        for i in range(len(self._loss_fns)):
          y_t = tf.cast(y_true[i], self._dtype)
          y_p = tf.cast(y_pred[i], self._dtype)
          batch_loss += self._loss_weights[i] * self._loss_fns[i](y_t, y_p)
        batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

      return super().update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before calling a tff.Computation. Here we store them in a TFF-specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.compat.v1.add_to_collection(
        graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
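# Hedged usage sketch (the model and losses here are invented for
# illustration): a two-output functional model requires one loss per output,
# and `loss_weights`, when given, must be a sequence of the same length.
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
out_a = tf.keras.layers.Dense(1, name='a')(inputs)
out_b = tf.keras.layers.Dense(1, name='b')(inputs)
two_headed = tf.keras.Model(inputs=inputs, outputs=[out_a, out_b])
loss_fns = [tf.keras.losses.MeanSquaredError(),
            tf.keras.losses.MeanAbsoluteError()]
loss_weights = [0.7, 0.3]  # a length mismatch with loss_fns raises ValueError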
def federated_output_computation(self):
  # Identity aggregation: returns each client's local outputs unchanged.
  return tff_core.federated_computation(lambda x: x)
def __init__(self, inner_model, input_spec, loss_fns, loss_weights=None,
             metrics=None):
  self._input_spec = input_spec

  if not loss_fns:
    raise ValueError(
        'Must specify at least one loss function in loss_fns, got: '
        '{l}'.format(l=loss_fns))
  if (len(tf.nest.flatten(loss_fns)) !=
      len(tf.nest.flatten(inner_model.output))):
    raise ValueError('Must specify the same number of loss_fns as model '
                     'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
                         l=loss_fns, o=inner_model.output))
  self._loss_fns = loss_fns

  if loss_weights is None:
    loss_weights = [1.0] * len(loss_fns)
  else:
    py_typecheck.check_type(loss_weights, collections.Sequence)
    if len(loss_weights) != len(loss_fns):
      raise ValueError('Must specify the same number of '
                       'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
                       'loss_weights: {lw}\nloss_fns: {lf}'.format(
                           lw=loss_weights,
                           llw=len(loss_weights),
                           lf=loss_fns,
                           llf=len(loss_fns)))
  self._loss_weights = loss_weights

  self._keras_model = inner_model
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fns` and
  # `loss_weights`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super().__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      if len(self._loss_fns) == 1:
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        batch_loss = tf.zeros(())
        for i in range(len(self._loss_fns)):
          y_t = tf.cast(y_true[i], self._dtype)
          y_p = tf.cast(y_pred[i], self._dtype)
          batch_loss += self._loss_weights[i] * self._loss_fns[i](y_t, y_p)
        batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

      return super().update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)
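# Sketch (illustrative values) of the metric-type construction above:
# `tf.TensorSpec.from_tensor` captures the type of one client's metric
# variables, and `tff.FederatedType(..., tff.CLIENTS)` places that type at
# the clients so TFF knows what `federated_output` receives.
import collections
import tensorflow as tf
import tensorflow_federated as tff

local_output = collections.OrderedDict(
    loss=tf.constant([0.0, 0.0]))  # stand-in for one client's metric variables
type_dict = tf.nest.map_structure(tf.TensorSpec.from_tensor, local_output)
# type_dict['loss'] == tf.TensorSpec(shape=(2,), dtype=tf.float32)
federated_type = tff.FederatedType(type_dict, tff.CLIENTS)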