def __evaluate(
    weights: tff.learning.ModelWeights,
    client_states: Dict[int, ClientState],
    dataset: FederatedDataset,
    evaluation_fn: Callable[
        [tff.learning.ModelWeights, List[tf.data.Dataset], List[ClientState]],
        Tuple[tf.Tensor, Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]]
) -> None:
  """Runs federated evaluation and logs the results to MLflow.

  Calls `evaluation_fn` with the model weights plus the per-client datasets
  and client states (both ordered by `dataset.clients`), then logs:
    * the confusion matrix as a heatmap figure,
    * the aggregated precision/recall curve as a figure,
    * per-client PR-AUC and binary accuracy as scalar metrics.

  Args:
    weights: Model weights to evaluate.
    client_states: Per-client state, keyed by client id.
    dataset: Federated dataset exposing `clients` (ids) and `data`
      (mapping client id -> tf.data.Dataset).
    evaluation_fn: Returns a (confusion_matrix, aggregated_metrics,
      per_client_metrics) tuple; per-client entries are assumed to be in the
      same order as the inputs, i.e. `dataset.clients`.
  """
  confusion_matrix, aggregated_metrics, client_metrics = evaluation_fn(
      weights,
      [dataset.data[client] for client in dataset.clients],
      [client_states[client] for client in dataset.clients])

  # Confusion matrix heatmap.
  fig, ax = plt.subplots(figsize=(16, 8))
  sns.heatmap(
      confusion_matrix,
      annot=True,
      fmt='d',
      cmap=sns.color_palette("Blues"),
      ax=ax)
  ax.set_xlabel('Predicted')
  ax.set_ylabel('Ground Truth')
  # Plain literal: the original used an f-string with no placeholders (F541).
  mlflow.log_figure(fig, 'confusion_matrix.png')
  plt.close(fig)

  # Aggregated precision/recall curve.
  fig, ax = plt.subplots(figsize=(16, 8))
  sns.lineplot(
      x=aggregated_metrics['recall'],
      y=aggregated_metrics['precision'],
      ax=ax)
  ax.set_xlabel('Recall')
  ax.set_xlim(0., 1.)
  ax.set_ylabel('Precision')
  ax.set_ylim(0., 1.)
  mlflow.log_figure(fig, 'precision_recall.png')
  plt.close(fig)

  # Per-client metrics: restore each client's metric variables into fresh
  # metric objects so `result()` reproduces that client's value.
  auc = metrics.SigmoidDecorator(tf.keras.metrics.AUC(curve='PR'), name='auc')
  accuracy = metrics.SigmoidDecorator(
      tf.keras.metrics.BinaryAccuracy(), name='accuracy')

  # BUGFIX: `client_metrics` is ordered like the inputs to `evaluation_fn`,
  # i.e. by `dataset.clients`. Zipping against `client_states.keys()` (as the
  # original did) silently mislabels metrics whenever the dict's iteration
  # order differs from `dataset.clients`.
  for client, metric in zip(dataset.clients, client_metrics):
    tf.nest.map_structure(
        lambda v, t: v.assign(t), auc.variables, list(metric['auc']))
    tf.nest.map_structure(
        lambda v, t: v.assign(t), accuracy.variables,
        list(metric['accuracy']))

    mlflow.log_metric(f'client_{client}_val_auc', auc.result().numpy())
    mlflow.log_metric(f'client_{client}_val_acc', accuracy.result().numpy())
def __evaluation_metrics_fn() -> List[tf.keras.metrics.Metric]:
  """Builds the metric collection used during federated evaluation.

  Every metric is wrapped in `metrics.SigmoidDecorator` — presumably so the
  wrapped metric receives sigmoid-activated predictions rather than raw
  logits (see the decorator's definition to confirm).
  """
  # 200 evenly spaced decision thresholds covering [0, 1).
  thresholds = np.linspace(0, 1, 200, endpoint=False).tolist()

  auc = metrics.SigmoidDecorator(tf.keras.metrics.AUC(curve='PR'), name='auc')
  precision = metrics.SigmoidDecorator(
      tf.keras.metrics.Precision(thresholds=thresholds), name='precision')
  recall = metrics.SigmoidDecorator(
      tf.keras.metrics.Recall(thresholds=thresholds), name='recall')
  accuracy = metrics.SigmoidDecorator(
      tf.keras.metrics.BinaryAccuracy(), name='accuracy')

  return [auc, precision, recall, accuracy]
def __training_metrics_fn() -> List[tf.keras.metrics.Metric]:
  """Builds the metric list tracked during training.

  Only PR-AUC is tracked while training; it is wrapped in
  `metrics.SigmoidDecorator` like the evaluation metrics (the decorator
  presumably applies a sigmoid to predictions first — defined elsewhere).
  """
  pr_auc = tf.keras.metrics.AUC(curve='PR')
  return [metrics.SigmoidDecorator(pr_auc, name='auc')]