def client_model_fn(weights: tff.learning.ModelWeights,
                    client_state: client.State) -> tff.learning.Model:
    """Create a new model with both of its sub-models loaded.

    Args:
      weights: weights copied into the new model's `base_model`.
      client_state: state whose `model` weights are copied into the new
        model's `personalized_model`.

    Returns:
      The freshly built model from `model_fn()`, with both sub-models
      assigned.
    """
    built = model_fn()
    weights.assign_weights_to(built.base_model)
    personal_weights = client_state.model
    personal_weights.assign_weights_to(built.personalized_model)
    return built
def evaluate_fn(reference_model: tff.learning.ModelWeights) -> Dict[str, Any]:
    """Evaluation function to be used during training.

    Evaluates `reference_model` separately on each client in `client_ids`
    (taken from the enclosing scope). It then summarizes every Keras metric
    across those clients.

    Args:
      reference_model: the `tff.learning.ModelWeights` to evaluate.

    Returns:
      A dict with keys `avg_<metric>`, `min_<metric>`, `max_<metric>` and
      `std_<metric>`. The values are the mean, minimum, maximum and
      population standard deviation (`np.std` default, ddof=0) of that
      metric over all evaluated clients.

    Raises:
      TypeError: if `reference_model` is not a `tff.learning.ModelWeights`.
      ValueError: if `client_ids` yields no clients, so there is nothing
        to summarize.
    """
    if not isinstance(reference_model, tff.learning.ModelWeights):
        raise TypeError('The reference model used for evaluation must be a '
                        '`tff.learning.ModelWeights` instance.')
    keras_model = compiled_eval_keras_model()
    reference_model.assign_weights_to(keras_model)
    logging.info('Evaluating the current model')

    # metric name -> list of per-client values. Filled lazily, because
    # `keras_model.metrics_names` may only be populated once `evaluate`
    # has run.
    results = {}
    num_clients = 0
    for client_id in client_ids:
        num_clients += 1
        client_data = federated_eval_dataset.create_tf_dataset_for_client(
            client_id)
        eval_tuple_dataset = convert_to_tuple_dataset(client_data)
        eval_metrics = keras_model.evaluate(eval_tuple_dataset, verbose=0)
        # `evaluate` returns a bare scalar (not a list) when the model only
        # tracks its loss; normalize so it can be zipped with the names.
        if not isinstance(eval_metrics, (list, tuple)):
            eval_metrics = [eval_metrics]
        for name, value in zip(keras_model.metrics_names, eval_metrics):
            results.setdefault(name, []).append(value)

    if num_clients == 0:
        raise ValueError('No clients to evaluate: `client_ids` is empty.')

    statistics_dict = {}
    for name, values in results.items():
        statistics_dict[f'avg_{name}'] = np.mean(values)
        statistics_dict[f'min_{name}'] = np.min(values)
        statistics_dict[f'max_{name}'] = np.max(values)
        statistics_dict[f'std_{name}'] = np.std(values)
    return statistics_dict
def evaluate_fn(reference_model: tff.learning.ModelWeights) -> Dict[str, Any]:
    """Evaluation function to be used during training.

    Evaluates `reference_model` on `eval_tuple_dataset`, which is taken from
    the enclosing scope.

    NOTE(review): a per-client `evaluate_fn` with the same name is defined
    just before this one. At module level, whichever definition runs last
    wins. Confirm which one is intended.

    Args:
      reference_model: the `tff.learning.ModelWeights` to evaluate.

    Returns:
      A dict mapping each name in `keras_model.metrics_names` to its
      evaluated value.

    Raises:
      TypeError: if `reference_model` is not a `tff.learning.ModelWeights`.
    """
    if not isinstance(reference_model, tff.learning.ModelWeights):
        raise TypeError('The reference model used for evaluation must be a '
                        '`tff.learning.ModelWeights` instance.')
    keras_model = compiled_eval_keras_model()
    reference_model.assign_weights_to(keras_model)
    logging.info('Evaluating the current model')
    eval_metrics = keras_model.evaluate(eval_tuple_dataset, verbose=0)
    # `evaluate` returns a bare scalar (not a list) when the model only
    # tracks its loss; zipping a scalar would raise TypeError.
    if not isinstance(eval_metrics, (list, tuple)):
        eval_metrics = [eval_metrics]
    return dict(zip(keras_model.metrics_names, eval_metrics))
def validate(dataset: tf.data.Dataset, state: State,
             weights: tff.learning.ModelWeights,
             model_fn: Callable) -> Validation:
    """Run the model over `dataset` without training and report its metrics.

    Args:
      dataset: batches fed to the model's `forward_pass`.
      state: supplies `client_pos_weight`, which is passed to `model_fn` as
        `pos_weight`.
      weights: weights assigned to the freshly built model before the pass.
      model_fn: factory used to build the model.

    Returns:
      A `Validation` whose `metrics` is `model.report_local_outputs()`,
      read after every batch has been processed.
    """
    # `tf.init_scope` lifts the model (and its variable creation) out of
    # any enclosing function-building graph.
    with tf.init_scope():
        local_model = model_fn(pos_weight=state.client_pos_weight)
    weights.assign_weights_to(local_model)
    for example_batch in dataset:
        # Outputs are discarded. Presumably the pass is run for its effect
        # on the model's local metric state, read below.
        local_model.forward_pass(example_batch, training=False)
    local_metrics = local_model.report_local_outputs()
    return Validation(metrics=local_metrics)
# NOTE(review): the `cls` first parameter suggests this is meant to be a
# `@classmethod`; no decorator is visible here. Confirm.
def assign_weights_to_keras_model(
        cls, reference_model: tff.learning.ModelWeights,
        keras_model: tf.keras.Model):
    """Copy the weights held by `reference_model` into `keras_model`.

    Args:
      reference_model: the `tff.learning.ModelWeights` object to read
        weights from.
      keras_model: the `tf.keras.Model` whose weights are overwritten.

    Raises:
      TypeError: if `reference_model` is not a `tff.learning.ModelWeights`.
    """
    is_model_weights = isinstance(reference_model, tff.learning.ModelWeights)
    if is_model_weights:
        reference_model.assign_weights_to(keras_model)
        return
    raise TypeError('The reference model must be an instance of '
                    'tff.learning.ModelWeights.')
def evaluate(dataset: tf.data.Dataset, state: State,
             weights: tff.learning.ModelWeights,
             model_fn: Callable) -> Evaluation:
    """Evaluate a binary classifier on `dataset`.

    Args:
      dataset: batches whose second element (`batch[1]`) is used as the
        ground-truth label.
      state: supplies `client_pos_weight`, which is passed to `model_fn` as
        `pos_weight`.
      weights: weights assigned to the freshly built model.
      model_fn: factory used to build the model.

    Returns:
      An `Evaluation` holding:
        - the summed 2x2 int32 confusion matrix over all batches (rows are
          true labels, columns are predictions, per
          `tf.math.confusion_matrix`);
        - `model.report_local_outputs()` as `metrics`.
    """
    # `tf.init_scope` lifts model/variable creation out of any enclosing
    # function-building graph.
    with tf.init_scope():
        model = model_fn(pos_weight=state.client_pos_weight)
    weights.assign_weights_to(model)

    def accumulate(running_matrix, batch):
        outputs = model.forward_pass(batch, training=False)
        labels = tf.reshape(batch[1], (-1,))
        logits = tf.reshape(outputs.predictions, (-1,))
        # The predictions are treated as logits: sigmoid, then round to a
        # 0/1 class. `tf.round` rounds half to even, so exactly 0.5 -> 0.
        predicted = tf.round(tf.nn.sigmoid(logits))
        batch_matrix = tf.math.confusion_matrix(labels, predicted,
                                                num_classes=2)
        return running_matrix + batch_matrix

    initial_matrix = tf.zeros((2, 2), dtype=tf.int32)
    confusion_matrix = dataset.reduce(initial_matrix, accumulate)
    return Evaluation(confusion_matrix=confusion_matrix,
                      metrics=model.report_local_outputs())