Beispiel #1
0
    def client_model_fn(weights: tff.learning.ModelWeights,
                        client_state: client.State) -> tff.learning.Model:
        """Build a new model and load both weight sets into it.

        `model_fn` (resolved from the enclosing scope) creates the model.
        `weights` is copied into its `base_model`, and `client_state.model`
        is copied into its `personalized_model`.

        Args:
          weights: weights assigned to `base_model`.
          client_state: state whose `model` field is assigned to
            `personalized_model`.

        Returns:
          The newly constructed model with both sub-models populated.
        """
        fresh_model = model_fn()
        weights.assign_weights_to(fresh_model.base_model)
        personal_weights = client_state.model
        personal_weights.assign_weights_to(fresh_model.personalized_model)
        return fresh_model
Beispiel #2
0
  def evaluate_fn(reference_model: tff.learning.ModelWeights) -> Dict[str, Any]:
    """Evaluates `reference_model` per client and summarizes the metrics.

    The weights are loaded into a freshly compiled Keras evaluation model
    (`compiled_eval_keras_model()`). That model is then evaluated separately on
    the dataset of every client in `client_ids`. For each Keras metric the
    per-client values are reduced to their mean, minimum, maximum and standard
    deviation (`np.std`, i.e. ddof=0).

    `client_ids`, `federated_eval_dataset`, `convert_to_tuple_dataset` and
    `compiled_eval_keras_model` are free names resolved from an enclosing scope.

    Args:
      reference_model: the `tff.learning.ModelWeights` to evaluate.

    Returns:
      A dict with keys `avg_<name>`, `min_<name>`, `max_<name>` and
      `std_<name>` for every metric name reported by the Keras model. The dict
      is empty when no client produced any metrics (e.g. `client_ids` is empty).

    Raises:
      TypeError: if `reference_model` is not a `tff.learning.ModelWeights`.
    """
    if not isinstance(reference_model, tff.learning.ModelWeights):
      raise TypeError('The reference model used for evaluation must be a '
                      '`tff.learning.ModelWeights` instance.')

    keras_model = compiled_eval_keras_model()
    reference_model.assign_weights_to(keras_model)
    logging.info('Evaluating the current model')

    # Maps metric name -> list of per-client values, in client order.
    results = {}
    for client_id in client_ids:
      client_data = federated_eval_dataset.create_tf_dataset_for_client(
          client_id)
      eval_tuple_dataset = convert_to_tuple_dataset(client_data)
      eval_metrics = keras_model.evaluate(eval_tuple_dataset, verbose=0)
      # Keras returns a bare scalar (not a list) when the model reports only
      # its loss; wrap it so it pairs with the single entry of metrics_names.
      if not isinstance(eval_metrics, (list, tuple)):
        eval_metrics = [eval_metrics]
      # NOTE: metrics_names is read *after* evaluate() on purpose. Per the
      # Keras docs it may be empty until the model has been evaluated on data.
      for name, value in zip(keras_model.metrics_names, eval_metrics):
        results.setdefault(name, []).append(value)

    if not results:
      logging.warning('No client metrics were collected; '
                      'returning no statistics.')

    statistics_dict = {}
    for name, values in results.items():
      statistics_dict[f'avg_{name}'] = np.mean(values)
      statistics_dict[f'min_{name}'] = np.min(values)
      statistics_dict[f'max_{name}'] = np.max(values)
      statistics_dict[f'std_{name}'] = np.std(values)
    return statistics_dict
Beispiel #3
0
  def evaluate_fn(reference_model: tff.learning.ModelWeights) -> Dict[str, Any]:
    """Evaluates `reference_model` on the evaluation dataset.

    The weights are loaded into a freshly compiled Keras evaluation model
    (`compiled_eval_keras_model()`), which is evaluated on
    `eval_tuple_dataset`. That dataset is not defined in this function; it is
    resolved from an enclosing scope.

    Args:
      reference_model: the `tff.learning.ModelWeights` to evaluate.

    Returns:
      A dict mapping each name in `keras_model.metrics_names` to the value
      `keras_model.evaluate` reported for it.

    Raises:
      TypeError: if `reference_model` is not a `tff.learning.ModelWeights`.
    """
    if not isinstance(reference_model, tff.learning.ModelWeights):
      raise TypeError('The reference model used for evaluation must be a '
                      '`tff.learning.ModelWeights` instance.')

    keras_model = compiled_eval_keras_model()
    reference_model.assign_weights_to(keras_model)
    logging.info('Evaluating the current model')
    eval_metrics = keras_model.evaluate(eval_tuple_dataset, verbose=0)
    # Keras returns a bare scalar (not a list) when the model reports only its
    # loss; zip() would raise on a scalar, so wrap it in a list.
    if not isinstance(eval_metrics, (list, tuple)):
      eval_metrics = [eval_metrics]
    return dict(zip(keras_model.metrics_names, eval_metrics))
Beispiel #4
0
def validate(dataset: tf.data.Dataset, state: State,
             weights: tff.learning.ModelWeights,
             model_fn: Callable) -> Validation:
    """Run `weights` over `dataset` in inference mode and report metrics.

    The model is built inside `tf.init_scope()` by calling
    `model_fn(pos_weight=state.client_pos_weight)`. `weights` are then copied
    into it, and `forward_pass(..., training=False)` is called once per batch.
    The return values of `forward_pass` are discarded.

    Returns:
      A `Validation` whose `metrics` field is the model's
      `report_local_outputs()`.
    """
    with tf.init_scope():
        eval_model = model_fn(pos_weight=state.client_pos_weight)
    weights.assign_weights_to(eval_model)

    for example_batch in dataset:
        # Output unused; presumably called for its side effect of updating the
        # model's metric state — confirm against the model implementation.
        eval_model.forward_pass(example_batch, training=False)

    local_metrics = eval_model.report_local_outputs()
    return Validation(metrics=local_metrics)
Beispiel #5
0
    def assign_weights_to_keras_model(
            cls, reference_model: tff.learning.ModelWeights,
            keras_model: tf.keras.Model):
        """Copy the weights held in `reference_model` into `keras_model`.

        Args:
          reference_model: the `tff.learning.ModelWeights` to read weights from.
          keras_model: the `tf.keras.Model` whose weights are overwritten.

        Raises:
          TypeError: if `reference_model` is not a `tff.learning.ModelWeights`.
        """
        # NOTE(review): the first parameter is named `cls`, but no
        # `@classmethod` decorator is visible here; confirm how it is bound.
        if isinstance(reference_model, tff.learning.ModelWeights):
            reference_model.assign_weights_to(keras_model)
            return
        raise TypeError('The reference model must be an instance of '
                        'tff.learning.ModelWeights.')
Beispiel #6
0
def evaluate(dataset: tf.data.Dataset, state: State,
             weights: tff.learning.ModelWeights,
             model_fn: Callable,
             threshold: float = 0.5) -> Evaluation:
    """Evaluate `weights` on `dataset` and build a binary confusion matrix.

    A model is built via `model_fn(pos_weight=state.client_pos_weight)` inside
    `tf.init_scope()` and loaded with `weights`. It is then run over every
    batch of `dataset` with `training=False` inside `dataset.reduce`.

    Args:
      dataset: batches in which `batch[1]` is used as the labels (presumably
        `(features, labels)` pairs; TODO confirm against the producer).
      state: only `state.client_pos_weight` is read.
      weights: weights assigned to the freshly built model.
      model_fn: model factory, called with a `pos_weight` keyword argument.
      threshold: a prediction counts as class 1 when
        `sigmoid(prediction) > threshold`. The default of 0.5 reproduces the
        previous `tf.round(sigmoid(...))` behaviour, because `tf.round` rounds
        half to even and therefore maps exactly 0.5 to 0.

    Returns:
      An `Evaluation` holding the summed 2x2 int32 confusion matrix (rows are
      true labels and columns are predictions, as in `tf.math.confusion_matrix`)
      and the model's `report_local_outputs()`.
    """
    with tf.init_scope():
        model = model_fn(pos_weight=state.client_pos_weight)

    weights.assign_weights_to(model)

    # The accumulator is named `running_matrix` so it does not shadow the
    # outer `state` argument.
    def evaluation_fn(running_matrix, batch):
        outputs = model.forward_pass(batch, training=False)

        y_true = tf.reshape(batch[1], (-1, ))
        # NOTE(review): predictions are treated as logits (sigmoid is applied
        # here); confirm the model does not already output probabilities.
        probabilities = tf.nn.sigmoid(
            tf.reshape(outputs.predictions, (-1, )))
        y_pred = tf.cast(probabilities > threshold, tf.int32)

        return running_matrix + tf.math.confusion_matrix(
            y_true, y_pred, num_classes=2)

    confusion_matrix = dataset.reduce(tf.zeros((2, 2), dtype=tf.int32),
                                      evaluation_fn)

    return Evaluation(confusion_matrix=confusion_matrix,
                      metrics=model.report_local_outputs())