示例#1
0
def __evaluate(
    weights: tff.learning.ModelWeights, client_states: Dict[int, ClientState],
    dataset: FederatedDataset, evaluation_fn: Callable[
        [tff.learning.ModelWeights, List[tf.data.Dataset], List[ClientState]],
        Tuple[tf.Tensor, Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]]
) -> None:
    """Evaluate ``weights`` on every client of ``dataset`` and log to MLflow.

    ``evaluation_fn`` is called once with the data and state of all clients in
    ``dataset.clients``. The following artifacts/metrics are then logged:

    * ``confusion_matrix.png`` -- annotated heatmap of the confusion matrix;
    * ``precision_recall.png`` -- the ``'precision'`` vs ``'recall'`` entries
      of the aggregated metrics, both axes limited to [0, 1];
    * ``client_<id>_val_auc`` / ``client_<id>_val_acc`` -- per-client PR-AUC
      and binary accuracy, obtained by loading each client's metric values
      into local metric objects and reading ``result()``.

    Args:
        weights: Model weights to evaluate.
        client_states: Per-client state keyed by client id. Must contain every
            id listed in ``dataset.clients``; extra entries are ignored.
        dataset: Federated dataset. ``dataset.clients`` selects (and orders)
            the evaluated clients; ``dataset.data[client]`` is a client's data.
        evaluation_fn: Returns ``(confusion_matrix, aggregated_metrics,
            client_metrics)``. ``client_metrics`` is iterated and each element
            indexed with ``'auc'`` and ``'accuracy'``, so it is presumably a
            sequence of per-client mappings in the order of the client lists
            passed in. NOTE(review): that disagrees with the ``Dict``
            annotation above -- confirm against the implementation.

    Raises:
        KeyError: if a client in ``dataset.clients`` is missing from
            ``client_states``.
    """
    # Materialise once so both input lists -- and the labels used when
    # logging per-client metrics below -- share one client order.
    clients = list(dataset.clients)

    confusion_matrix, aggregated_metrics, client_metrics = evaluation_fn(
        weights, [dataset.data[client] for client in clients],
        [client_states[client] for client in clients])

    # Confusion matrix. fmt='d' formats annotations as integers, which
    # assumes the matrix holds integer counts.
    fig, ax = plt.subplots(figsize=(16, 8))
    try:
        sns.heatmap(confusion_matrix,
                    annot=True,
                    fmt='d',
                    cmap=sns.color_palette("Blues"),
                    ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Ground Truth')
        mlflow.log_figure(fig, 'confusion_matrix.png')
    finally:
        # Close even when plotting/logging fails so figures do not pile up.
        plt.close(fig)

    # Precision-recall curve.
    fig, ax = plt.subplots(figsize=(16, 8))
    try:
        sns.lineplot(x=aggregated_metrics['recall'],
                     y=aggregated_metrics['precision'],
                     ax=ax)
        ax.set_xlabel('Recall')
        ax.set_xlim(0., 1.)
        ax.set_ylabel('Precision')
        ax.set_ylim(0., 1.)
        mlflow.log_figure(fig, 'precision_recall.png')
    finally:
        plt.close(fig)

    # Client metrics: local metric objects whose variables are overwritten
    # with each client's values so that result() reports that client's score.
    auc = metrics.SigmoidDecorator(tf.keras.metrics.AUC(curve='PR'),
                                   name='auc')
    accuracy = metrics.SigmoidDecorator(tf.keras.metrics.BinaryAccuracy(),
                                        name='accuracy')

    # Label with ``clients`` (not ``client_states``): client_states may hold
    # other clients or a different key order than the inputs evaluated above.
    for client, metric in zip(clients, client_metrics):
        tf.nest.map_structure(lambda v, t: v.assign(t), auc.variables,
                              list(metric['auc']))
        tf.nest.map_structure(lambda v, t: v.assign(t), accuracy.variables,
                              list(metric['accuracy']))

        mlflow.log_metric(f'client_{client}_val_auc', auc.result().numpy())
        mlflow.log_metric(f'client_{client}_val_acc',
                          accuracy.result().numpy())
示例#2
0
def __evaluation_metrics_fn() -> List[tf.keras.metrics.Metric]:
    """Create the metrics computed during evaluation.

    Every metric is wrapped in ``metrics.SigmoidDecorator``. In order:

    * ``'auc'`` -- area under the precision-recall curve;
    * ``'precision'`` and ``'recall'`` -- evaluated at 200 evenly spaced
      thresholds covering [0, 1) (0 included, 1 excluded);
    * ``'accuracy'`` -- binary accuracy.

    Returns:
        A freshly constructed list of the four metrics above.
    """
    # 0.000, 0.005, ..., 0.995
    grid = list(np.linspace(0, 1, 200, endpoint=False))

    # Builders are invoked one at a time inside the comprehension, so each
    # Keras metric is constructed right before it is wrapped.
    builders = (
        ('auc', lambda: tf.keras.metrics.AUC(curve='PR')),
        ('precision', lambda: tf.keras.metrics.Precision(thresholds=grid)),
        ('recall', lambda: tf.keras.metrics.Recall(thresholds=grid)),
        ('accuracy', lambda: tf.keras.metrics.BinaryAccuracy()),
    )
    return [
        metrics.SigmoidDecorator(build(), name=label)
        for label, build in builders
    ]
示例#3
0
def __training_metrics_fn() -> List[tf.keras.metrics.Metric]:
    """Create the metrics tracked during training.

    Returns:
        A one-element list: precision-recall AUC wrapped in
        ``metrics.SigmoidDecorator`` and named ``'auc'``.
    """
    pr_auc = tf.keras.metrics.AUC(curve='PR')
    wrapped = metrics.SigmoidDecorator(pr_auc, name='auc')
    return [wrapped]