# Beispiel #1
import tensorflow as tf
import tensorflow_federated as tff

# TF 1.x-era compatibility call: globally enable resource variables.
tf.enable_resource_variables()


@tff.federated_computation
def hello_word():
    """No-arg federated computation that returns a constant greeting."""
    greeting = "Hello, World!"
    return greeting


# Invoke the federated computation and print its result.
print(hello_word())

# %%
# A federated float placed at tff.CLIENTS.
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
print(str(federated_float_on_clients.member))
print(str(federated_float_on_clients.placement))
print(str(federated_float_on_clients))
# Whether member values are constrained to be equal across all clients.
print(federated_float_on_clients.all_equal)
print(tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True))

# %%
# A simple two-parameter regression model expressed as a named tuple type.
simple_regression_model_type = (
    tff.NamedTupleType([('a', tf.float32), ('b', tf.float32)])
)
print(str(simple_regression_model_type))
print(str(tff.FederatedType(simple_regression_model_type, tff.CLIENTS, all_equal=True)))


# %%
# Beispiel #2
    # NOTE(review): orphaned fragment — the enclosing `def` header is missing
    # from this snippet. These lines build a model and SGD optimizer and
    # delegate to a `client_update` helper defined elsewhere.
    model = model_fn()
    client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    loss = tf.Variable(0.0, trainable=False, dtype=tf.float32)
    return client_update(model, tf_dataset, server_weights, client_optimizer,
                         loss)


# Wrap the server-update logic as a TFF TensorFlow computation.
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
    """Apply the averaged client weights to a freshly built model."""
    return server_update(model_fn(), mean_client_weights)


# Convert the model-parameter and dataset structures into federated types.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)


# One round of the federated learning process.
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    # Broadcast the server model weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)

    # Each client computes its local update and returns new weights and loss.
    client_weights, clients_loss = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))

    # The server averages the model parameters updated by all clients.
    mean_client_weights = tff.federated_mean(client_weights)
    # NOTE(review): snippet appears truncated — no return statement visible.
# Beispiel #3
def build_run_one_round_fn_attacked(server_update_fn, client_update_fn,
                                    stateful_delta_aggregate_fn,
                                    dummy_model_for_metadata,
                                    federated_server_state_type,
                                    federated_dataset_type):
  """Builds a `tff.federated_computation` for a round of training.

  Args:
    server_update_fn: A function for updates in the server.
    client_update_fn: A function for updates in the clients.
    stateful_delta_aggregate_fn: A `tff.computation` that takes in model deltas
      placed@CLIENTS to an aggregated model delta placed@SERVER.
    dummy_model_for_metadata: A dummy `tff.learning.Model`.
    federated_server_state_type: type_signature of federated server state.
    federated_dataset_type: type_signature of federated dataset.

  Returns:
    A `tff.federated_computation` for a round of training.
  """

  # Per-client boolean marking which clients act maliciously this round.
  federated_bool_type = tff.FederatedType(tf.bool, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type, federated_dataset_type,
                             federated_bool_type)
  def run_one_round(server_state, federated_dataset, malicious_dataset,
                    malicious_clients):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      malicious_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`
        consisting of malicious datasets.
      malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """

    # Ship the current global model to every client.
    client_model = tff.federated_broadcast(server_state.model)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, malicious_dataset, malicious_clients, client_model))

    weight_denom = client_outputs.weights_delta_weight

    # Aggregate client deltas with the (possibly stateful) aggregator,
    # threading its state through the round.
    new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, round_model_delta, new_delta_aggregate_state))

    # Aggregate client metrics; zip struct-typed results onto the server.
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return run_one_round
# Beispiel #4
def build_fixed_clip_norm_mean_process(
    *,
    clip_norm: float,
    model_update_type: Union[tff.NamedTupleType, tff.TensorType],
) -> tff.templates.MeasuredProcess:
  """Returns process that clips the client deltas before averaging.

  The returned `MeasuredProcess` has a next function with the TFF type
  signature:

  ```
  (<()@SERVER, {model_update_type}@CLIENTS> ->
   <state=()@SERVER,
    result=model_update_type@SERVER,
    measurements=NormClippedAggregationMetrics@SERVER>)
  ```

  Args:
    clip_norm: the clip norm to apply to the global norm of the model update.
      See https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm for
        details.
    model_update_type: a `tff.Type` describing the shape and type of the value
      that will be clipped and averaged.

  Returns:
    A `tff.templates.MeasuredProcess` with the type signature detailed above.
  """

  @tff.federated_computation
  def initialize_fn():
    # The process is stateless: the state is an empty tuple at the server.
    return tff.federated_value((), tff.SERVER)

  @tff.federated_computation(
      tff.FederatedType((), tff.SERVER),
      tff.FederatedType(model_update_type, tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.CLIENTS))
  def next_fn(state, deltas, weights):

    @tff.tf_computation(model_update_type)
    def clip_by_global_norm(update):
      # Clip the flattened structure so its global norm is at most clip_norm.
      clipped_update, global_norm = tf.clip_by_global_norm(
          tf.nest.flatten(update), tf.constant(clip_norm))
      # 1 when the pre-clip norm exceeded the bound (i.e. clipping occurred).
      was_clipped = tf.cond(
          tf.greater(global_norm, tf.constant(clip_norm)),
          lambda: tf.constant(1),
          lambda: tf.constant(0),
      )
      # Restore the original nested structure of the update.
      clipped_update = tf.nest.pack_sequence_as(update, clipped_update)
      return clipped_update, global_norm, was_clipped

    clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
        clip_by_global_norm, deltas)

    # Weighted mean of clipped deltas plus clipping metrics at the server.
    return collections.OrderedDict(
        state=state,
        result=tff.federated_mean(clipped_deltas, weight=weights),
        measurements=tff.federated_zip(
            NormClippedAggregationMetrics(
                max_global_norm=tff.utils.federated_max(client_norms),
                num_clipped=tff.federated_sum(client_was_clipped),
            )))

  return tff.templates.MeasuredProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
# Beispiel #5
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a
        `simple_fedavg_tf.KerasModelWrapper`.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for server update.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for client update.

    Returns:
      A `tff.utils.IterativeProcess`.
    """

    # Used only to derive type signatures (input spec); never trained.
    dummy_model = model_fn()

    @tff.tf_computation
    def server_init_tf():
        model = model_fn()
        server_optimizer = server_optimizer_fn()
        # Create optimizer variables up front so they can be carried in the
        # server state below.
        _initialize_optimizer_vars(model, server_optimizer)
        return ServerState(model_weights=model.weights,
                           optimizer_state=server_optimizer.variables(),
                           round_num=0)

    server_state_type = server_init_tf.type_signature.result

    model_weights_type = server_state_type.model_weights

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        model = model_fn()
        server_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    @tff.tf_computation(server_state_type)
    def server_message_fn(server_state):
        # Package the parts of the server state that clients need.
        return build_server_broadcast_message(server_state)

    server_message_type = server_message_fn.type_signature.result
    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

    @tff.tf_computation(tf_dataset_type, server_message_type)
    def client_update_fn(tf_dataset, server_message):
        model = model_fn()
        client_optimizer = client_optimizer_fn()
        return client_update(model, tf_dataset, server_message,
                             client_optimizer)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and `tf.Tensor` of average loss.
        """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        # Weight each client's contribution by its reported client_weight.
        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))
        round_loss_metric = tff.federated_mean(client_outputs.model_output,
                                               weight=weight_denom)

        return server_state, round_loss_metric

    @tff.federated_computation
    def server_init_tff():
        """Orchestration logic for server model initialization."""
        return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.utils.IterativeProcess(initialize_fn=server_init_tff,
                                      next_fn=run_one_round)
        # NOTE(review): orphaned fragment — this tail belongs to a client
        # update function whose header is not part of this snippet.
        # Apply the gradient using the client optimizer.
        client_optimizer.apply_gradients(grads_and_vars)

    return client_weights


@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
    """Run one round of local training for a single client."""
    # Fresh TFF-wrapped Keras model and Adam optimizer per invocation.
    model = wrap_model_with_tff(keras_model(), input_spec)
    optimizer = tf.keras.optimizers.Adam()
    return client_update(model, tf_dataset, server_weights, optimizer)


# Federated type signatures: model weights at the server, datasets at clients.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_data = tff.FederatedType(tf_dataset_type, tff.CLIENTS)


@tff.federated_computation(federated_server_type, federated_dataset_data)
def next_fn(server_weights, federated_dataset):
    # Send server weights to clients.
    server_weights_to_clients = tff.federated_broadcast(server_weights)

    # Each client computes its updated weights locally.
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_to_clients))

    # Average the client weights at the server.
    mean_client_weights = tff.federated_mean(client_weights)
    # NOTE(review): snippet appears truncated — no return statement visible.
# Beispiel #7
def create_trainer(batch_size, step_size):
    """Constructs a trainer for the given batch size.

    Args:
      batch_size: The size of a single data batch.
      step_size: The step size to use during training.

    Returns:
      An instance of `Trainer`.
    """
    batch_type = tff.to_type(
        collections.OrderedDict([
            ('pixels', tff.TensorType(np.float32, (batch_size, 784))),
            ('labels', tff.TensorType(np.int32, (batch_size, )))
        ]))

    model_type = tff.to_type(
        collections.OrderedDict([('weights',
                                  tff.TensorType(np.float32, (784, 10))),
                                 ('bias', tff.TensorType(np.float32,
                                                         (10, )))]))

    @tff.experimental.jax_computation
    def create_zero_model():
        # All-zeros linear model: 784 inputs -> 10 classes.
        weights = jax.numpy.zeros((784, 10), dtype=np.float32)
        bias = jax.numpy.zeros((10, ), dtype=np.float32)
        return collections.OrderedDict([('weights', weights), ('bias', bias)])

    def generate_random_batches(num_batches):
        # Yields random (pixels, labels) batches matching `batch_type`.
        for _ in range(num_batches):
            pixels = np.random.uniform(low=0.0,
                                       high=1.0,
                                       size=(batch_size,
                                             784)).astype(np.float32)
            labels = np.random.randint(low=0,
                                       high=9,
                                       size=(batch_size, ),
                                       dtype=np.int32)
            yield collections.OrderedDict([('pixels', pixels),
                                           ('labels', labels)])

    def _loss_fn(model, batch):
        # Softmax cross-entropy loss of the linear model on one batch.
        y = jax.nn.softmax(
            jax.numpy.add(jax.numpy.matmul(batch['pixels'], model['weights']),
                          model['bias']))
        targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
        return -jax.numpy.mean(
            jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

    @tff.experimental.jax_computation(model_type, batch_type)
    def train_on_one_batch(model, batch):
        # Single gradient-descent step. BUGFIX: the legacy `jax.api` module
        # was removed from public JAX; `jax.grad` is the stable spelling of
        # the same transformation.
        grads = jax.grad(_loss_fn)(model, batch)
        return collections.OrderedDict([(k, model[k] - step_size * grads[k])
                                        for k in ['weights', 'bias']])

    @tff.federated_computation(model_type, tff.SequenceType(batch_type))
    def train_on_one_client(model, batches):
        return tff.sequence_reduce(batches, model, train_on_one_batch)

    local_training_process = tff.templates.IterativeProcess(
        initialize_fn=create_zero_model, next_fn=train_on_one_client)

    # TODO(b/175888145): Switch to a simple tff.federated_mean after finding a
    # way to reduce reliance on the auto-generated TF bits in the executor stack
    # for the GENERIC_PLUS and similar intrinsics.

    @tff.experimental.jax_computation
    def create_zero_count():
        return np.int32(0)

    @tff.experimental.jax_computation
    def create_one_count():
        return np.int32(1)

    @tff.experimental.jax_computation(model_type, model_type)
    def combine_two_models(x, y):
        # Element-wise sum of two models (aggregation accumulator step).
        return collections.OrderedDict([
            ('weights', jax.numpy.add(x['weights'], y['weights'])),
            ('bias', jax.numpy.add(x['bias'], y['bias']))
        ])

    @tff.experimental.jax_computation(model_type, np.int32)
    def divide_model_by_count(model, count):
        multiplier = 1.0 / count.astype(np.float32)
        return collections.OrderedDict([
            ('weights', jax.numpy.multiply(model['weights'], multiplier)),
            ('bias', jax.numpy.multiply(model['bias'], multiplier))
        ])

    @tff.experimental.jax_computation(np.int32, np.int32)
    def combine_two_counts(x, y):
        return jax.numpy.add(x, y)

    @tff.federated_computation
    def make_zero_model_and_count():
        return collections.OrderedDict([('model', create_zero_model()),
                                        ('count', create_zero_count())])

    model_and_count_type = make_zero_model_and_count.type_signature.result

    @tff.federated_computation(model_and_count_type, model_type)
    def accumulate(arg):
        # TODO(b/175888145): Diagnose the newly emergent problem with tuple arg
        # handling that gets in the way by forcing named elements here at input
        # (i.e., we can't just declare `def accumulate(accumulator, model)` for
        # reasons that yet need to be understood).
        accumulator = arg[0]
        model = arg[1]
        return collections.OrderedDict([
            ('model', combine_two_models(accumulator['model'], model)),
            ('count',
             combine_two_counts(accumulator['count'], create_one_count()))
        ])

    @tff.federated_computation(model_and_count_type, model_and_count_type)
    def merge(arg):
        x = arg[0]
        y = arg[1]
        return collections.OrderedDict([
            ('model', combine_two_models(x['model'], y['model'])),
            ('count', combine_two_counts(x['count'], y['count']))
        ])

    @tff.federated_computation(model_and_count_type)
    def report(x):
        # Final aggregation step: summed model divided by contributor count.
        return divide_model_by_count(x['model'], x['count'])

    @tff.federated_computation
    def create_zero_model_on_server():
        return tff.federated_eval(create_zero_model, tff.SERVER)

    @tff.federated_computation(tff.FederatedType(model_type, tff.SERVER),
                               tff.FederatedType(tff.SequenceType(batch_type),
                                                 tff.CLIENTS))
    def train_one_round(model, federated_data):
        locally_trained_models = tff.federated_map(
            train_on_one_client,
            collections.OrderedDict([('model', tff.federated_broadcast(model)),
                                     ('batches', federated_data)]))
        # Hand-rolled federated mean via federated_aggregate (see TODO above).
        return tff.federated_aggregate(locally_trained_models,
                                       make_zero_model_and_count(), accumulate,
                                       merge, report)

    federated_averaging_process = tff.templates.IterativeProcess(
        initialize_fn=create_zero_model_on_server, next_fn=train_one_round)

    compute_loss_on_one_batch = tff.experimental.jax_computation(
        _loss_fn, model_type, batch_type)

    return Trainer(create_initial_model=create_zero_model,
                   generate_random_batches=generate_random_batches,
                   train_on_one_batch=train_on_one_batch,
                   train_on_one_client=train_on_one_client,
                   local_training_process=local_training_process,
                   train_one_round=train_one_round,
                   federated_averaging_process=federated_averaging_process,
                   compute_loss_on_one_batch=compute_loss_on_one_batch)
# Beispiel #8
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

  # Normalize scalar learning rates into `round_num -> lr` schedules.
  client_lr_schedule = client_lr
  if not callable(client_lr_schedule):
    client_lr_schedule = lambda round_num: client_lr

  server_lr_schedule = server_lr
  if not callable(server_lr_schedule):
    server_lr_schedule = lambda round_num: server_lr

  # Used only for type signatures and the metrics-aggregation computation.
  dummy_model = model_fn()

  server_init_tf = build_server_init_fn(
      model_fn,
      # Initialize with the learning rate for round zero.
      lambda: server_optimizer_fn(server_lr_schedule(0)))
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model
  round_num_type = server_state_type.round_num

  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  model_input_type = tff.SequenceType(dummy_model.input_spec)

  @tff.tf_computation(model_input_type, model_weights_type, round_num_type)
  def client_update_fn(tf_dataset, initial_model_weights, round_num):
    # Look up this round's client learning rate from the schedule.
    client_lr = client_lr_schedule(round_num)
    client_optimizer = client_optimizer_fn(client_lr)
    client_update = create_client_update_fn()
    return client_update(model_fn(), tf_dataset, initial_model_weights,
                         client_optimizer, client_weight_fn)

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    model = model_fn()
    server_lr = server_lr_schedule(server_state.round_num)
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(tf_dataset_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num))

    # Weight each client's delta by its reported client_weight.
    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, model_delta))

    # Aggregate client metrics; zip struct-typed results onto the server.
    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  @tff.federated_computation
  def initialize_fn():
    return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(
      initialize_fn=initialize_fn, next_fn=run_one_round)
# Beispiel #9
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> FederatedAveragingProcessAdapter:
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A function that accepts a `learning_rate` keyword
        argument and returns a `tf.keras.optimizers.Optimizer` instance.
      client_lr: A scalar learning rate or a function that accepts a float
        `round_num` argument and returns a learning rate.
      server_optimizer_fn: A function that accepts a `learning_rate` argument
        and returns a `tf.keras.optimizers.Optimizer` instance.
      server_lr: A scalar learning rate or a function that accepts a float
        `round_num` argument and returns a learning rate.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of model deltas. If not provided, the
        default is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of values
        and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
        If `None`, no dataset preprocessing is applied.

    Returns:
      A `FederatedAveragingProcessAdapter`.
    """

    # Normalize scalar learning rates into `round_num -> lr` schedules.
    client_lr_schedule = client_lr
    if not callable(client_lr_schedule):
        client_lr_schedule = lambda round_num: client_lr

    server_lr_schedule = server_lr
    if not callable(server_lr_schedule):
        server_lr_schedule = lambda round_num: server_lr

    # Used only for type signatures and the metrics-aggregation computation.
    dummy_model = model_fn()

    server_init_tf = build_server_init_fn(
        model_fn,
        # Initialize with the learning rate for round zero.
        lambda: server_optimizer_fn(server_lr_schedule(0)))
    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model
    round_num_type = server_state_type.round_num

    if dataset_preprocess_comp is not None:
        # Validate that the preprocessing output matches the model's input.
        tf_dataset_type = dataset_preprocess_comp.type_signature.parameter
        model_input_type = tff.SequenceType(dummy_model.input_spec)
        preprocessed_dataset_type = dataset_preprocess_comp.type_signature.result
        if not model_input_type.is_assignable_from(preprocessed_dataset_type):
            raise TypeError(
                'Supplied `dataset_preprocess_comp` does not yield '
                'batches that are compatible with the model constructed '
                'by `model_fn`. Model expects type {m}, but dataset '
                'yields type {d}.'.format(m=model_input_type,
                                          d=preprocessed_dataset_type))
    else:
        tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
        model_input_type = tff.SequenceType(dummy_model.input_spec)

    @tff.tf_computation(model_input_type, model_weights_type, round_num_type)
    def client_update_fn(tf_dataset, initial_model_weights, round_num):
        # Look up this round's client learning rate from the schedule.
        client_lr = client_lr_schedule(round_num)
        client_optimizer = client_optimizer_fn(client_lr)
        client_update = create_client_update_fn()
        return client_update(model_fn(), tf_dataset, initial_model_weights,
                             client_optimizer, client_weight_fn)

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        model = model_fn()
        server_lr = server_lr_schedule(server_state.round_num)
        server_optimizer = server_optimizer_fn(server_lr)
        # We initialize the server optimizer variables to avoid creating them
        # within the scope of the tf.function server_update.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               tff.FederatedType(tf_dataset_type, tff.CLIENTS))
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)
        if dataset_preprocess_comp is not None:
            federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                                  federated_dataset)
        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num))

        # Weight each client's delta by its reported client_weight.
        client_weight = client_outputs.client_weight
        model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=client_weight)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, model_delta))

        # Aggregate client metrics; zip tuple-typed results onto the server.
        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        # NOTE(review): `Type.is_tuple` was renamed `is_struct` in later TFF
        # releases — confirm against the pinned TFF version.
        if aggregated_outputs.type_signature.is_tuple():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    @tff.federated_computation
    def initialize_fn():
        return tff.federated_value(server_init_tf(), tff.SERVER)

    tff_iterative_process = tff.templates.IterativeProcess(
        initialize_fn=initialize_fn, next_fn=run_one_round)

    return FederatedAveragingProcessAdapter(tff_iterative_process)
# Beispiel #10
import collections
import time
import numpy as np
import grpc
import sys
import absl

import tensorflow as tf
import tensorflow_federated as tff

import nest_asyncio
nest_asyncio.apply()

@tff.tf_computation(tf.int64)
@tf.function
def add_one(n):
    """Log the input value, then return it incremented by one."""
    tf.print("Hello: ", n, output_stream=absl.logging.info)
    incremented = tf.add(n, 1)
    return incremented


@tff.federated_computation(tff.FederatedType(tf.int64, tff.CLIENTS))
def add_one_on_clients(federated_n):
    """Apply `add_one` pointwise to a CLIENTS-placed federated value."""
    incremented = tff.federated_map(add_one, federated_n)
    return incremented



# Driver: run the federated computation on a single-client example input.
print(add_one_on_clients([1]))
# Beispiel #11
    def __attrs_post_init__(self):
        """Derives the TFF type signatures used by the GAN training process.

        Instantiates the generator/discriminator models and their optimizers,
        captures their variable structures as `tf.TensorSpec`-based types, and
        assembles the federated (placed) types for client/server inputs. When
        a DP average query is configured, also builds the DP aggregation
        function and its state type.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types
        self._generator = self.generator_model_fn()
        # Call once on the dummy input so the model's variables are created.
        _ = self._generator(self.dummy_gen_input)
        py_typecheck.check_type(self._generator, tf.keras.models.Model)
        self._discriminator = self.discriminator_model_fn()
        _ = self._discriminator(self.dummy_real_data)
        py_typecheck.check_type(self._discriminator, tf.keras.models.Model)
        self._state_gen_opt = self.state_gen_optimizer_fn(1)
        self._state_disc_opt = self.state_disc_optimizer_fn(1)
        # Create optimizer variables so their structure can be typed below.
        gan_training_tf_fns.initialize_optimizer_vars(self._generator,
                                                      self._state_gen_opt)
        gan_training_tf_fns.initialize_optimizer_vars(self._discriminator,
                                                      self._state_disc_opt)
        self._counters = collections.OrderedDict({
            'num_discriminator_train_examples':
            tf.constant(0),
            'num_generator_train_examples':
            tf.constant(0),
            'num_rounds':
            tf.constant(0),
        })

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            # NOTE(review): the tf.cast types every weight as float32 —
            # confirm that is intended for non-float32 variables.
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(
                    tf.cast(v.read_value(), tf.float32)), var_struct)

        def vars_to_type_counter(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v), var_struct)

        self.discriminator_weights_type = vars_to_type(
            gan_training_tf_fns._weights(self._discriminator))
        self.generator_weights_type = vars_to_type(
            gan_training_tf_fns._weights(self._generator))
        self.state_gen_opt_weights_type = vars_to_type(
            self._state_gen_opt.variables())
        self.state_disc_opt_weights_type = vars_to_type(
            self._state_disc_opt.variables())
        self.counters_type = vars_to_type_counter(self._counters)
        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type,
            state_gen_optimizer_weights=self.state_gen_opt_weights_type,
            state_disc_optimizer_weights=self.state_disc_opt_weights_type,
            counters=self.counters_type)
        self.client_gen_input_type = tff.FederatedType(
            tff.SequenceType(self.gen_input_type), tff.CLIENTS)
        self.client_real_data_type = tff.FederatedType(
            tff.SequenceType(self.real_data_type), tff.CLIENTS)
        self.server_gen_input_type = tff.FederatedType(
            tff.SequenceType(self.gen_input_type), tff.SERVER)

        # Right now, the logic in this library is effectively "if DP use stateful
        # aggregator, else don't use stateful aggregator". An alternative
        # formulation would be to always use a stateful aggregator, but when not
        # using DP default the aggregator to be a stateless mean, e.g.,
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        # This change will be easier to make if the tff.StatefulAggregateFn is
        # modified to have a property that gives the type of the aggregation state
        # (i.e., what we're storing in self.dp_averaging_state_type).
        if self.train_discriminator_dp_average_query is not None:
            self.dp_averaging_fn, self.dp_averaging_state_type = (
                tff.utils.build_dp_aggregate(
                    query=self.train_discriminator_dp_average_query,
                    value_type_fn=lambda value: self.discriminator_weights_type
                ))
# Beispiel #12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple temperature sensor example in TFF."""

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff


@tff.tf_computation(tff.SequenceType(tf.float32), tf.float32)
def count_over(ds, t):
    """Counts how many elements of the sequence `ds` exceed the threshold `t`."""
    def _add_if_over(count, reading):
        # Contributes 1.0 for a reading strictly above the threshold, else 0.0.
        return count + tf.cast(tf.greater(reading, t), tf.float32)

    return ds.reduce(np.float32(0), _add_if_over)


@tff.tf_computation(tff.SequenceType(tf.float32))
def count_total(ds):
    """Counts the total number of elements in the sequence `ds`."""
    def _increment(count, _unused_reading):
        return count + 1.0

    return ds.reduce(np.float32(0.0), _increment)


@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS),
    tff.FederatedType(tf.float32, tff.SERVER))
def mean_over_threshold(temperatures, threshold):
    """Returns the mean per-client fraction of readings above `threshold`."""
    # Ship the server-side threshold value to every client.
    client_threshold = tff.federated_broadcast(threshold)
    # Pair each client's local sequence with its copy of the threshold.
    zipped_inputs = tff.federated_zip([temperatures, client_threshold])
    # Per-client count of readings above the threshold.
    over_counts = tff.federated_map(count_over, zipped_inputs)
    # Per-client total number of readings, used as the averaging weight.
    totals = tff.federated_map(count_total, temperatures)
    return tff.federated_mean(over_counts, totals)
Beispiel #13
0
def build_jax_federated_averaging_process(batch_type, model_type, loss_fn,
                                          step_size):
  """Constructs an iterative process that implements simple federated averaging.

  Args:
    batch_type: An instance of `tff.Type` that represents the type of a single
      batch of data to use for training. This type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the sort
      that are expected as parameters to `loss_fn`.
    model_type: An instance of `tff.Type` that represents the type of the model.
      Similarly to `batch_size`, this type should be constructed with standard
      Python containers (such as `collections.OrderedDict`) of the sort that are
      expected as parameters to `loss_fn`.
    loss_fn: A loss function for the model. Must be a Python function that takes
      two parameters, one of them being the model, and the other being a single
      batch of data (with types matching `batch_type` and `model_type`).
    step_size: The step size to use during training (an `np.float32`).

  Returns:
    An instance of `tff.templates.IterativeProcess` that implements federated
    training in JAX.
  """
  # Normalize Python container specs into concrete `tff.Type` instances.
  batch_type = tff.to_type(batch_type)
  model_type = tff.to_type(model_type)

  # py_typecheck.check_type(batch_type, computation_types.Type)
  # py_typecheck.check_type(model_type, computation_types.Type)
  # py_typecheck.check_callable(loss_fn)
  # py_typecheck.check_type(step_size, np.float)

  def _tensor_zeros(tensor_type):
    # Builds a JAX zeros array matching a single leaf `tff.TensorType`.
    return jax.numpy.zeros(
        tensor_type.shape.dims, dtype=tensor_type.dtype.as_numpy_dtype)

  @tff.jax_computation
  def _create_zero_model():
    # Zero-initializes every tensor in the model structure, then re-wraps the
    # flat TFF structure back into the caller's Python container type.
    model_zeros = tff.structure.map_structure(_tensor_zeros, model_type)
    return tff.types.type_to_py_container(model_zeros, model_type)

  @tff.federated_computation
  def _create_zero_model_on_server():
    # Server-placed initial model; used as the process's `initialize` step.
    return tff.federated_eval(_create_zero_model, tff.SERVER)

  def _apply_update(model_param, param_delta):
    # Plain SGD step on one flat parameter tensor.
    return model_param - step_size * param_delta

  @tff.jax_computation(model_type, batch_type)
  def _train_on_one_batch(model, batch):
    # Flatten model and gradients into parallel leaf lists so they can be
    # updated pairwise, then repack into the original model structure.
    params = tff.structure.flatten(
        tff.structure.from_container(model, recursive=True))
    grads = tff.structure.flatten(
        tff.structure.from_container(jax.grad(loss_fn)(model, batch)))
    updated_params = [_apply_update(x, y) for (x, y) in zip(params, grads)]
    trained_model = tff.structure.pack_sequence_as(model_type, updated_params)
    return tff.types.type_to_py_container(trained_model, model_type)

  local_dataset_type = tff.SequenceType(batch_type)

  @tff.federated_computation(model_type, local_dataset_type)
  def _train_on_one_client(model, batches):
    # Folds `_train_on_one_batch` over the client's local batches.
    return tff.sequence_reduce(batches, model, _train_on_one_batch)

  @tff.federated_computation(
      tff.FederatedType(model_type, tff.SERVER),
      tff.FederatedType(local_dataset_type, tff.CLIENTS))
  def _train_one_round(model, federated_data):
    # Broadcast the server model, train locally on each client, then average
    # the locally trained models (unweighted) back on the server.
    locally_trained_models = tff.federated_map(
        _train_on_one_client,
        collections.OrderedDict([('model', tff.federated_broadcast(model)),
                                 ('batches', federated_data)]))
    return tff.federated_mean(locally_trained_models)

  return tff.templates.IterativeProcess(
      initialize_fn=_create_zero_model_on_server, next_fn=_train_one_round)
Beispiel #14
0
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for server update.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for client update.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

    # A throwaway model instance, used only to derive type signatures and the
    # federated output computation.
    dummy_model = model_fn(
    )  # TODO(b/144510813): try remove dependency on dummy model

    @tff.tf_computation
    def server_init_tf():
        # Creates the initial server state: fresh model weights plus the
        # server optimizer's variables.
        model = model_fn()
        server_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(model, server_optimizer)
        return ServerState(model=model.weights,
                           optimizer_state=server_optimizer.variables())

    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        model = model_fn()
        server_optimizer = server_optimizer_fn()
        # Optimizer variables must exist before `server_update` executes them
        # inside its tf.function scope.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_update_fn(tf_dataset, initial_model_weights):
        # Runs local training on one client's dataset starting from the
        # broadcast model weights.
        model = model_fn()
        client_optimizer = client_optimizer_fn()
        return client_update(model, tf_dataset, initial_model_weights,
                             client_optimizer)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        # Broadcast the current global model to all clients.
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        # Weighted average of client deltas; each client's weight comes from
        # its own `client_weight` output.
        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        # Apply the averaged delta to the server model via the server optimizer.
        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))

        # Aggregate client-side metrics into a single server-placed structure.
        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    return tff.utils.IterativeProcess(initialize_fn=tff.federated_computation(
        lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
                                      next_fn=run_one_round)
Beispiel #15
0
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    aggregation_process: Optional[tff.templates.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    aggregation_process: A `tff.templates.MeasuredProcess` used to aggregate
      client model deltas. NOTE(review): although annotated `Optional`, it is
      dereferenced unconditionally below, so callers must supply a non-None
      value.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

    # Normalize the learning rates into schedules (constant if a scalar).
    client_lr_schedule = client_lr
    if not callable(client_lr_schedule):
        client_lr_schedule = lambda round_num: client_lr

    server_lr_schedule = server_lr
    if not callable(server_lr_schedule):
        server_lr_schedule = lambda round_num: server_lr

    # Trace a throwaway model/optimizer pair in a scratch graph purely to
    # capture their TFF type signatures.
    with tf.Graph().as_default():
        dummy_model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(dummy_model)
        dummy_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(dummy_model, dummy_optimizer)
        optimizer_variable_type = tff.framework.type_from_tensors(
            dummy_optimizer.variables())

    initialize_computation = build_server_init_fn(
        model_fn=model_fn,
        # Initialize with the learning rate for round zero.
        server_optimizer_fn=lambda: server_optimizer_fn(server_lr_schedule(0)),
        aggregation_process=aggregation_process)
    #model_weights_type = tff.framework.type_from_tensors(_get_weights(dummy_model).trainable)
    round_num_type = tf.float32

    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
    model_input_type = tff.SequenceType(dummy_model.input_spec)
    client_weight_type = tf.float32

    # NOTE(review): assumes `aggregation_process` is not None despite the
    # Optional annotation — this raises AttributeError otherwise.
    aggregation_state_type = aggregation_process.initialize.type_signature.result.member

    server_state_type = ServerState(
        model=model_weights_type,
        optimizer_state=optimizer_variable_type,
        round_num=round_num_type,
        aggregation_state=aggregation_state_type,
    )

    @tff.tf_computation(model_input_type, model_weights_type, round_num_type)
    def client_update_fn(tf_dataset, initial_model_weights, round_num):
        # The client learning rate may depend on the (broadcast) round number.
        client_lr = client_lr_schedule(round_num)
        client_optimizer = client_optimizer_fn(client_lr)
        client_update = create_client_update_fn()
        return client_update(model_fn(), tf_dataset, initial_model_weights,
                             client_optimizer)

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        model = model_fn()
        server_lr = server_lr_schedule(server_state.round_num)
        server_optimizer = server_optimizer_fn(server_lr)
        # We initialize the server optimizer variables to avoid creating them
        # within the scope of the tf.function server_update.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    # @tff.tf_computation(tf.float32, tf.float32)
    # def local_mul(weight, participated):
    #   return tf.math.multiply(weight, participated)

    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               tff.FederatedType(tf_dataset_type, tff.CLIENTS),
                               tff.FederatedType(client_weight_type,
                                                 tff.CLIENTS))
    def run_one_round(server_state, federated_dataset, client_weight):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      client_weight: A federated float with placement `tff.CLIENTS`, combined
        multiplicatively with each client's reported weight.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num))

        #client_weight = client_outputs.client_weight
        # model_delta = tff.federated_mean(
        #     client_outputs.weights_delta, weight=client_weight)

        # Combine the externally supplied per-client weight with the weight
        # each client reported from local training.
        participant_client_weight = tff.federated_map(
            tff.tf_computation(lambda x, y: x * y),
            (client_weight, client_outputs.client_weight))

        # Delegate aggregation of model deltas to the (stateful) process.
        aggregation_output = aggregation_process.next(
            server_state.aggregation_state, client_outputs.weights_delta,
            participant_client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregation_output.result))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    # @tff.federated_computation
    # def initialize_fn():
    #   return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                          next_fn=run_one_round)
Beispiel #16
0
    def test_build_with_preprocess_funtion(self):
        """Checks initialize/next type signatures when a dataset preprocessing computation is supplied."""
        test_dataset = tf.data.Dataset.range(5)
        client_datasets_type = tff.FederatedType(
            tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            # Maps raw range elements to (x, y) batches: y = 3x + 1.
            def to_batch(x):
                return collections.OrderedDict(x=[float(x) * 1.0],
                                               y=[float(x) * 3.0 + 1.0])

            return ds.map(to_batch).repeat().batch(2).take(3)

        # decay_factor=1.0 means the callbacks never actually change the LR;
        # they exist here only so the state/type plumbing is exercised.
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            dataset_preprocess_comp=preprocess_dataset)

        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        # Expected server state: a 1x1 linear model plus optimizer step
        # counter and both LR callbacks, placed at the server.
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)
        expected_result_type = (server_state_type, output_type)

        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
Beispiel #17
0
 async def _encrypt_values_on_singleton(self, val, sender, receiver):
     """Encrypts `val`, placed on the singleton `sender`, once per receiver child.

     Looks up the sender's secret key and the receiver's public keys,
     materializes an encryptor computation, and runs it on the sender's
     single child executor once for each public key. Returns a
     `FederatedResolvingStrategyValue` holding a struct of encrypted
     values, each typed as federated at `sender` with all_equal=False.
     """
     ###
     # we can safely assume  sender has cardinality=1 when receiver is CLIENTS
     ###
     # Case 1: receiver=CLIENTS
     #     plaintext: Fed(Tensor, sender, all_equal=True)
     #     pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
     #     sk_sender: Fed(Tensor, sender, all_equal=True)
     #   Returns:
     #     encrypted_values: Tuple(Fed(Tensor, sender, all_equal=True))
     ###
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     assert sk_sender.type_signature.placement is sender
     assert pk_receiver.type_signature.placement is sender
     ### Check placement cardinalities
     rcv_children = self.strategy._get_child_executors(receiver)
     snd_children = self.strategy._get_child_executors(sender)
     py_typecheck.check_len(snd_children, 1)
     snd_child = snd_children[0]
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     py_typecheck.check_len(val.internal_representation, 1)
     py_typecheck.check_type(pk_receiver.type_signature.member,
                             tff.StructType)
     # One public key expected per receiver child executor.
     py_typecheck.check_len(pk_receiver.internal_representation,
                            len(rcv_children))
     py_typecheck.check_len(sk_sender.internal_representation, 1)
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     self._input_type_cache = input_type
     pk_rcv_type = pk_receiver.type_signature.member
     sk_snd_type = sk_sender.type_signature.member
     pk_element_type = pk_rcv_type[0]
     encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Prepare encryption arguments
     v = val.internal_representation[0]
     sk = sk_sender.internal_representation[0]
     ### Encrypt values and return them
     encryptor_fn = await snd_child.create_value(encryptor_proto,
                                                 encryptor_type)
     # One encryption call per receiver public key, all on the sender's
     # single child executor; gather to run them concurrently.
     encryptor_args = await asyncio.gather(*[
         snd_child.create_struct([v, this_pk, sk])
         for this_pk in pk_receiver.internal_representation
     ])
     encrypted_values = await asyncio.gather(*[
         snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
     ])
     encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(encrypted_values),
         tff.StructType([
             tff.FederatedType(evt, sender, all_equal=False)
             for evt in encrypted_value_types
         ]))
def build_fed_avg_process(model_fn,
                          client_lr_callback,
                          client_callback_update_fn,
                          server_lr_callback,
                          server_callback_update_fn,
                          client_optimizer_fn=tf.keras.optimizers.SGD,
                          server_optimizer_fn=tf.keras.optimizers.SGD,
                          client_weight_fn=None):
    """Builds the TFF computations for FedAvg with learning rate decay.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_lr_callback: A `ReduceLROnPlateau` callback.
    client_callback_update_fn: A function that updates the client callback.
    server_lr_callback: A `ReduceLROnPlateau` callback.
    server_callback_update_fn: A function that updates the server callback.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

    # Throwaway model used only for type signatures and output aggregation.
    dummy_model = model_fn()
    # Metric names whose aggregated values drive the LR plateau logic.
    client_monitor = client_lr_callback.monitor
    server_monitor = server_lr_callback.monitor

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn,
                                          client_lr_callback,
                                          server_lr_callback)

    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model
    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

    client_lr_type = server_state_type.client_lr_callback.learning_rate
    client_monitor_value_type = server_state_type.client_lr_callback.best
    server_monitor_value_type = server_state_type.server_lr_callback.best

    @tff.tf_computation(tf_dataset_type, model_weights_type, client_lr_type)
    def client_update_fn(tf_dataset, initial_model_weights, client_lr):
        client_optimizer = client_optimizer_fn(learning_rate=client_lr)
        # Metrics on the *initial* weights are needed so the server can
        # detect a plateau before this round's training is applied.
        initial_model_output = get_client_output(model_fn(), tf_dataset,
                                                 initial_model_weights)
        client_state = client_update(model_fn(), tf_dataset,
                                     initial_model_weights, client_optimizer,
                                     client_weight_fn)
        return tff.utils.update_state(
            client_state, initial_model_output=initial_model_output)

    @tff.tf_computation(server_state_type, model_weights_type.trainable,
                        client_monitor_value_type, server_monitor_value_type)
    def server_update_fn(server_state, model_delta, client_monitor_value,
                         server_monitor_value):
        model = model_fn()
        server_lr = server_state.server_lr_callback.learning_rate
        server_optimizer = server_optimizer_fn(learning_rate=server_lr)
        # We initialize the server optimizer variables to avoid creating them
        # within the scope of the tf.function server_update.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta, client_monitor_value,
                             client_callback_update_fn, server_monitor_value,
                             server_callback_update_fn)

    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               tff.FederatedType(tf_dataset_type, tff.CLIENTS))
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Note that in addition to updating the server weights according to the client
    model weight deltas, we extract metrics (governed by the `monitor` attribute
    of the `client_lr_callback` and `server_lr_callback` attributes of the
    `server_state`) and use these to update the client learning rate callbacks.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation` before and during local
      client training.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_lr = tff.federated_broadcast(
            server_state.client_lr_callback.learning_rate)
        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, client_model, client_lr))

        client_weight = client_outputs.client_weight
        aggregated_gradients = tff.federated_mean(
            client_outputs.accumulated_gradients, weight=client_weight)

        # Metrics evaluated on the pre-training weights; these feed the LR
        # plateau callbacks below.
        initial_aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.initial_model_output)
        if isinstance(initial_aggregated_outputs.type_signature,
                      tff.StructType):
            initial_aggregated_outputs = tff.federated_zip(
                initial_aggregated_outputs)

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        # Pull out the monitored metrics that drive each callback's decay.
        client_monitor_value = initial_aggregated_outputs[client_monitor]
        server_monitor_value = initial_aggregated_outputs[server_monitor]

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregated_gradients,
                               client_monitor_value, server_monitor_value))

        result = collections.OrderedDict(
            before_training=initial_aggregated_outputs,
            during_training=aggregated_outputs)

        return server_state, result

    @tff.federated_computation
    def initialize_fn():
        # Places the freshly constructed server state at the server.
        return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.templates.IterativeProcess(initialize_fn=initialize_fn,
                                          next_fn=run_one_round)
Beispiel #19
0
    def __init__(
        self,
        model_fn,
        m,
        n,
        j_max,
        importance_sampling,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    ):
        """Builds the TFF computations for optimization using federated averaging.
        Args:
        model_fn: A no-arg function that returns a
          `simple_fedavg_tf.KerasModelWrapper`.
        server_optimizer_fn: A no-arg function that returns a
          `tf.keras.optimizers.Optimizer` for server update.
        client_optimizer_fn: A no-arg function that returns a
          `tf.keras.optimizers.Optimizer` for client update.
        Returns:
        A `tff.templates.IterativeProcess`.
        """

        dummy_model = model_fn()

        @tff.tf_computation
        def server_init_tf():
            # Builds a fresh model/optimizer pair and returns the initial
            # server state: weights, optimizer slot variables, round counter.
            model = model_fn()
            server_optimizer = server_optimizer_fn()
            _initialize_optimizer_vars(model, server_optimizer)
            return ServerState(model_weights=model.weights,
                               optimizer_state=server_optimizer.variables(),
                               round_num=0)

        server_state_type = server_init_tf.type_signature.result

        model_weights_type = server_state_type.model_weights

        @tff.tf_computation(server_state_type, model_weights_type.trainable)
        def server_update_fn(server_state, model_delta):
            model = model_fn()
            server_optimizer = server_optimizer_fn()
            # Optimizer variables must exist before `server_update` runs
            # inside its tf.function scope.
            _initialize_optimizer_vars(model, server_optimizer)
            return server_update(model, server_optimizer, server_state,
                                 model_delta)

        @tff.tf_computation(server_state_type)
        def server_message_fn(server_state):
            # Packs the pieces of server state that clients need for a round.
            return build_server_broadcast_message(server_state)

        server_message_type = server_message_fn.type_signature.result
        tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

        @tff.tf_computation(tf_dataset_type, server_message_type)
        def client_update_fn(tf_dataset, server_message):
            # Local training on one client's dataset, starting from the
            # weights carried in the broadcast server message.
            model = model_fn()
            client_optimizer = client_optimizer_fn()
            return client_update(model, tf_dataset, server_message,
                                 client_optimizer)

        federated_server_state_type = tff.FederatedType(
            server_state_type, tff.SERVER)
        federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                                   tff.CLIENTS)

        @tff.tf_computation(
            tf.float32,
            tf.float32,
        )
        def scale(update_norm, sum_update_norms):
            # NOTE(review): `importance_sampling` is a plain Python value
            # captured at tracing time, so only one branch is baked into the
            # resulting computation.
            if importance_sampling:
                # Sampling probability proportional to this client's update
                # norm (scaled by m, capped at 1).
                return tf.minimum(
                    1., tf.divide(tf.multiply(update_norm, m),
                                  sum_update_norms))
            else:
                # Uniform sampling probability m/n.
                return tf.divide(m, n)

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                                   tff.FederatedType(tf.float32, tff.CLIENTS,
                                                     True))
        def scale_on_clients(update_norm, sum_update_norms):
            # Applies `scale` pointwise on each client; the second argument
            # is an all-equal federated value (the broadcast sum of norms).
            return tff.federated_map(scale, (update_norm, sum_update_norms))

        @tff.tf_computation(tf.float32)
        def create_prob_message(prob):
            """Encodes a probability as a 2-vector: [prob, 1] if prob < 1, else [0, 0]."""
            def f1():
                return tf.stack([prob, 1.])

            def f2():
                return tf.constant([0., 0.])

            # Only probabilities strictly below 1 contribute to the
            # rescaling statistics aggregated on the server.
            prob_message = tf.cond(tf.less(prob, 1), f1, f2)
            return prob_message

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
        def create_prob_message_on_clients(prob):
            # Pointwise application of `create_prob_message` on each client.
            return tff.federated_map(create_prob_message, prob)

        @tff.tf_computation(tff.TensorType(tf.float32, (2, )))
        def compute_rescaling(prob_aggreg):
            # prob_aggreg = [sum of probs < 1, count of probs < 1]
            # (see `create_prob_message`); presumably this factor keeps the
            # expected number of sampled clients at m — TODO confirm.
            rescaling_factor = (m - n + prob_aggreg[1]) / prob_aggreg[0]
            return rescaling_factor

        @tff.federated_computation(
            tff.FederatedType(tff.TensorType(tf.float32, (2, )), tff.SERVER))
        def compute_rescaling_on_master(prob_aggreg):
            # Runs `compute_rescaling` on the server-placed aggregate.
            return tff.federated_map(compute_rescaling, prob_aggreg)

        @tff.tf_computation(tf.float32, tf.float32)
        def rescale_prob(prob, rescaling_factor):
            # Scales the probability by the factor, clamped to at most 1.
            return tf.minimum(1., tf.multiply(prob, rescaling_factor))

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                                   tff.FederatedType(tf.float32, tff.CLIENTS,
                                                     True))
        def rescale_prob_on_clients(prob, rescaling_factor):
            # Applies `rescale_prob` pointwise; `rescaling_factor` is an
            # all-equal broadcast value. (Fixes the typo'd parameter name
            # `rob` -> `prob`; behavior is unchanged.)
            return tff.federated_map(rescale_prob, (prob, rescaling_factor))

        @tff.tf_computation(tf.float32)
        def compute_weights_is_fn(prob):
            """Samples this client with probability `prob`; returns the importance weight 1/prob if sampled, else 0."""
            def f1():
                return 1. / prob

            def f2():
                return 0.

            # Bernoulli draw: sampled iff uniform(0,1) < prob.
            weight = tf.cond(tf.less(tf.random.uniform(()), prob), f1, f2)
            return weight

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
        def compute_weights_is(prob):
            # Per-client importance-sampling weight draw.
            return tff.federated_map(compute_weights_is_fn, prob)

        @tff.federated_computation(
            tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
            tff.FederatedType(tf.float32, tff.CLIENTS))
        def compute_round_model_delta(weights_delta, weights_denom):
            """Computes the weighted federated mean of client model deltas."""
            mean_delta = tff.federated_mean(weights_delta,
                                            weight=weights_denom)
            return mean_delta

        @tff.federated_computation(federated_server_state_type,
                                   tff.FederatedType(
                                       model_weights_type.trainable,
                                       tff.SERVER))
        def update_server_state(server_state, round_model_delta):
            """Applies the aggregated round delta to the server state."""
            updated_state = tff.federated_map(
                server_update_fn, (server_state, round_model_delta))
            return updated_state

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                                   tff.FederatedType(tf.float32, tff.CLIENTS))
        def compute_loss_metric(model_output, weight_denom):
            """Computes the weighted mean of per-client losses."""
            mean_loss = tff.federated_mean(model_output, weight=weight_denom)
            return mean_loss

        @tff.tf_computation(model_weights_type.trainable, tf.float32)
        def rescale_and_remove_fn(weights_delta, weights_is):
            """Multiplies every layer delta by the importance weight.

            A weight of 0 removes the client's update entirely; a weight of
            `1/prob` rescales it so the federated mean stays unbiased.
            """
            rescaled = []
            for layer_delta in weights_delta:
                rescaled.append(tf.math.scalar_mul(weights_is, layer_delta))
            return rescaled

        @tff.federated_computation(
            tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
            tff.FederatedType(tf.float32, tff.CLIENTS))
        def rescale_and_remove(weights_delta, weights_is):
            """Applies `rescale_and_remove_fn` on every client."""
            rescaled = tff.federated_map(rescale_and_remove_fn,
                                         (weights_delta, weights_is))
            return rescaled

        @tff.federated_computation(federated_server_state_type,
                                   federated_dataset_type)
        def run_gradient_computation_round(server_state, federated_dataset):
            """Runs one round of client gradient computation.

            Broadcasts the server message to the clients, runs the local
            client update, then derives an initial per-client sampling
            probability from each client's weighted update norm and the
            federated sum of all such norms (via `scale_on_clients`).

            Args:
              server_state: A `ServerState` placed at `tff.SERVER`.
              federated_dataset: A federated `tf.data.Dataset` with placement
                `tff.CLIENTS`.

            Returns:
              A `(prob_init, client_outputs)` tuple: the per-client initial
              probabilities and the per-client `ClientOutput`.
            """
            message = tff.federated_map(server_message_fn, server_state)
            message_at_clients = tff.federated_broadcast(message)

            client_outputs = tff.federated_map(
                client_update_fn, (federated_dataset, message_at_clients))

            norm_sum = tff.federated_sum(client_outputs.update_norm_weighted)
            norm_sum_at_clients = tff.federated_broadcast(norm_sum)

            prob_init = scale_on_clients(client_outputs.update_norm_weighted,
                                         norm_sum_at_clients)
            return prob_init, client_outputs

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
        def run_one_inner_loop_weights_computation(prob):
            """Runs one iteration of the probability-rescaling loop.

            Aggregates the clients' probability messages on the server,
            computes a rescaling factor there, broadcasts it back, and
            rescales (capping at 1) every client probability.

            Args:
              prob: Per-client sampling probabilities at `tff.CLIENTS`.

            Returns:
              A `(prob, rescaling_factor)` tuple with the updated per-client
              probabilities and the server-side rescaling factor.
            """
            messages = create_prob_message_on_clients(prob)
            aggregated = tff.federated_sum(messages)
            factor_at_server = compute_rescaling_on_master(aggregated)
            factor_at_clients = tff.federated_broadcast(factor_at_server)
            updated_prob = rescale_prob_on_clients(prob, factor_at_clients)

            return updated_prob, factor_at_server

        @tff.federated_computation
        def server_init_tff():
            """Places the initial server state at `tff.SERVER`."""
            initial_state = server_init_tf()
            return tff.federated_value(initial_state, tff.SERVER)

        def run_one_round(server_state, federated_dataset):
            """Orchestration logic for one round of computation.

            Runs the gradient-computation round, optionally iterates the
            importance-sampling probability rescaling loop, and applies the
            rescaled weighted mean of client deltas to the server state.
            Note: this is a plain Python function (not a
            `tff.federated_computation`), so the federated computations it
            calls execute eagerly and return concrete values.
            Args:
              server_state: A `ServerState`.
              federated_dataset: A federated `tf.data.Dataset` with placement
                `tff.CLIENTS`.
            Returns:
            A tuple of updated `ServerState`, `tf.Tensor` of average loss, and
            the final per-client sampling probabilities as a list of numpy
            scalars.
            """
            prob, client_outputs = run_gradient_computation_round(
                server_state, federated_dataset)

            if importance_sampling:
                # Iterate at most `j_max` times; once the rescaling factor
                # drops to <= 1 the probabilities have converged.
                for j in range(j_max):
                    prob, rescaling_factor = run_one_inner_loop_weights_computation(
                        prob)
                    if rescaling_factor <= 1:
                        break

            weight_denom = [
                client_output.client_weight for client_output in client_outputs
            ]
            weights_delta = [
                client_output.weights_delta for client_output in client_outputs
            ]

            # rescale weights based on sampling procedure
            weights_is = compute_weights_is(prob)
            weights_delta = rescale_and_remove(weights_delta, weights_is)

            round_model_delta = compute_round_model_delta(
                weights_delta, weight_denom)

            server_state = update_server_state(server_state, round_model_delta)

            model_output = [
                client_output.model_output for client_output in client_outputs
            ]
            round_loss_metric = compute_loss_metric(model_output, weight_denom)

            # Materialize the per-client probabilities for logging/analysis.
            prob_numpy = []
            for p in prob:
                prob_numpy.append(p.numpy())

            return server_state, round_loss_metric, prob_numpy

        # Expose the round function and initializer as the standard
        # iterative-process interface of this object.
        self.next = run_one_round
        self.initialize = server_init_tff
Beispiel #20
0
def build_gan_training_process(gan: GanFnsAndTypes):
    """Constructs a `tff.Computation` for GAN training.

  Args:
    gan: A `GanFnsAndTypes` object.

  Returns:
    A `tff.utils.IterativeProcess` for GAN training.
  """

    # Generally, it is easiest to get the types correct by building
    # all of the needed tf_computations first, since this ensures we only
    # have non-federated types.
    server_initial_state = build_server_initial_state_comp(gan)
    server_state_type = server_initial_state.type_signature.result
    client_computation = build_client_computation(gan)
    client_output_type = client_computation.type_signature.result
    server_computation = build_server_computation(gan, server_state_type,
                                                  client_output_type)

    @tff.federated_computation
    def fed_server_initial_state():
        """Places the initial GAN server state at `tff.SERVER`."""
        return tff.federated_value(server_initial_state(), tff.SERVER)

    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               gan.server_gen_input_type,
                               gan.client_gen_input_type,
                               gan.client_real_data_type)
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned."""
        # TODO(b/131429028): The federated_zip should be automatic.
        # Broadcast the current generator/discriminator weights to clients.
        from_server = tff.federated_zip(
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights))
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set to None
            # here. This is because the DP aggregation code explicitly does not do
            # weighted aggregation. (If weighted aggregation is desired, differential
            # privacy needs to be turned off.)
            new_dp_averaging_state, averaged_discriminator_weights_delta = (
                gan.dp_averaging_fn(server_state.dp_averaging_state,
                                    client_outputs.discriminator_weights_delta,
                                    weight=None))

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        # TODO(b/131839522): This federated_zip shouldn't be needed.
        aggregated_client_output = tff.federated_zip(aggregated_client_output)

        # Apply the aggregated client output to the server state, on server.
        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state

    return tff.utils.IterativeProcess(fed_server_initial_state, run_one_round)
Beispiel #21
0
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_lr_callback: callbacks.ReduceLROnPlateau,
    server_lr_callback: callbacks.ReduceLROnPlateau,
    client_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    client_weight_fn: Optional[ClientWeightFn] = None,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> tff.templates.IterativeProcess:
    """Builds the TFF computations for FedAvg with learning rate decay.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_lr_callback: A `ReduceLROnPlateau` callback.
    server_lr_callback: A `ReduceLROnPlateau` callback.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`. If
      `None`, no dataset preprocessing is applied.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

    dummy_model = model_fn()
    # Metric names each learning-rate callback monitors.
    client_monitor = client_lr_callback.monitor
    server_monitor = server_lr_callback.monitor

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn,
                                          client_lr_callback,
                                          server_lr_callback)

    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model

    if dataset_preprocess_comp is not None:
        # Validate that the preprocessing pipeline yields batches the model
        # can consume before building any computations around it.
        tf_dataset_type = dataset_preprocess_comp.type_signature.parameter
        model_input_type = tff.SequenceType(dummy_model.input_spec)
        preprocessed_dataset_type = dataset_preprocess_comp.type_signature.result
        if not model_input_type.is_assignable_from(preprocessed_dataset_type):
            raise TypeError(
                'Supplied `dataset_preprocess_comp` does not yield '
                'batches that are compatible with the model constructed '
                'by `model_fn`. Model expects type {m}, but dataset '
                'yields type {d}.'.format(m=model_input_type,
                                          d=preprocessed_dataset_type))
    else:
        tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
        model_input_type = tff.SequenceType(dummy_model.input_spec)

    client_lr_type = server_state_type.client_lr_callback.learning_rate
    client_monitor_value_type = server_state_type.client_lr_callback.best
    server_monitor_value_type = server_state_type.server_lr_callback.best

    @tff.tf_computation(model_input_type, model_weights_type, client_lr_type)
    def client_update_fn(tf_dataset, initial_model_weights, client_lr):
        """Performs local training and records pre-training metrics."""
        client_optimizer = client_optimizer_fn(client_lr)
        initial_model_output = get_client_output(model_fn(), tf_dataset,
                                                 initial_model_weights)
        client_state = client_update(model_fn(), tf_dataset,
                                     initial_model_weights, client_optimizer,
                                     client_weight_fn)
        return tff.utils.update_state(
            client_state, initial_model_output=initial_model_output)

    @tff.tf_computation(server_state_type, model_weights_type.trainable,
                        client_monitor_value_type, server_monitor_value_type)
    def server_update_fn(server_state, model_delta, client_monitor_value,
                         server_monitor_value):
        """Applies the aggregated delta and updates both LR callbacks."""
        model = model_fn()
        server_lr = server_state.server_lr_callback.learning_rate
        server_optimizer = server_optimizer_fn(server_lr)
        # We initialize the server optimizer variables to avoid creating them
        # within the scope of the tf.function server_update.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta, client_monitor_value,
                             server_monitor_value)

    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               tff.FederatedType(tf_dataset_type, tff.CLIENTS))
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Note that in addition to updating the server weights according to the client
    model weight deltas, we extract metrics (governed by the `monitor` attribute
    of the `client_lr_callback` and `server_lr_callback` attributes of the
    `server_state`) and use these to update the client learning rate callbacks.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation` before and during local
      client training.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_lr = tff.federated_broadcast(
            server_state.client_lr_callback.learning_rate)

        if dataset_preprocess_comp is not None:
            federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                                  federated_dataset)
        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, client_model, client_lr))

        client_weight = client_outputs.client_weight
        aggregated_gradients = tff.federated_mean(
            client_outputs.accumulated_gradients, weight=client_weight)

        initial_aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.initial_model_output)
        if isinstance(initial_aggregated_outputs.type_signature,
                      tff.StructType):
            initial_aggregated_outputs = tff.federated_zip(
                initial_aggregated_outputs)

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        # Both callbacks are driven by metrics computed *before* local
        # training (initial_aggregated_outputs), not post-training metrics.
        client_monitor_value = initial_aggregated_outputs[client_monitor]
        server_monitor_value = initial_aggregated_outputs[server_monitor]

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregated_gradients,
                               client_monitor_value, server_monitor_value))

        result = collections.OrderedDict(
            before_training=initial_aggregated_outputs,
            during_training=aggregated_outputs)

        return server_state, result

    @tff.federated_computation
    def initialize_fn():
        """Places the initial server state at `tff.SERVER`."""
        return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.templates.IterativeProcess(initialize_fn=initialize_fn,
                                          next_fn=run_one_round)
Beispiel #22
0
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from federated_aggregations import paillier

NUM_CLIENTS = 5

# Route all TFF computations through the local Paillier executor stack so
# that `federated_secure_sum` is evaluated by the paillier aggregation
# backend (presumably Paillier homomorphic encryption — see the
# `federated_aggregations.paillier` package).
paillier_factory = paillier.local_paillier_executor_factory(NUM_CLIENTS)
paillier_context = tff.framework.ExecutionContext(paillier_factory)
tff.framework.set_default_context(paillier_context)


@tff.federated_computation(
    tff.FederatedType(tff.TensorType(tf.int32, [2]), tff.CLIENTS),
    tff.TensorType(tf.int32))
def secure_paillier_addition(x, bitwidth):
    """Securely sums the clients' int32 vectors under the given bitwidth."""
    total = tff.federated_secure_sum(x, bitwidth)
    return total


# Give each of the NUM_CLIENTS clients a distinct int32 vector [1+i, 2+i].
base = np.array([1, 2], np.int32)
x = [base + i for i in range(NUM_CLIENTS)]
# Elementwise secure sum across clients; 32 is the secure-sum bitwidth.
result = secure_paillier_addition(x, 32)
print(result)
Beispiel #23
0
        else:
            false_positive_rate[threshold] = 0.0
            false_discovery_rate[threshold] = 0.0
            harmonic_mean_fpr_fdr[threshold] = 0.0

        # The leaked_words in the next round must be a subset of this round.
        leaked_words_candidates = leaked_words
        bisect_upper_bound = below_threshold_index

    return false_positive_rate, false_discovery_rate, harmonic_mean_fpr_fdr


@tff.tf_computation(tff.SequenceType(tf.string))
def compute_lossless_result_per_user(dataset):
    """Extracts every word from one client's dataset, with no cap."""
    # tf.int32.max effectively disables the per-client contribution limit.
    no_limit = tf.constant(tf.int32.max)
    return get_top_elements(dataset, no_limit)


@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS))
def compute_lossless_results_federated(datasets):
    """Collects every client's full word list (no contribution cap)."""
    per_client_words = tff.federated_map(compute_lossless_result_per_user,
                                         datasets)
    return per_client_words


def compute_lossless_results(datasets):
    """Counts word occurrences across all client datasets.

    Returns a `{word: count}` dict over the concatenation of every
    client's words.
    """
    per_client_words = compute_lossless_results_federated(datasets)
    all_words = tf.concat(per_client_words, axis=0)
    unique_words, _, counts = tf.unique_with_counts(all_words)
    return dict(zip(unique_words.numpy(), counts.numpy()))
Beispiel #24
0
        return batch_train(model, batch, learning_rate)

    l = tff.sequence_reduce(all_batches, initial_model, batch_fn)
    return l


@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
    """Sums `batch_loss` over every batch of one client's local data."""
    per_batch_loss = tff.federated_computation(
        lambda b: batch_loss(model, b), BATCH_TYPE)
    return tff.sequence_sum(tff.sequence_map(per_batch_loss, all_batches))


# The model lives on the server (single all-equal value); each client holds
# its own local dataset.
SERVER_MODEL_TYPE = tff.FederatedType(MODEL_TYPE, tff.SERVER, all_equal=True)
CLIENT_DATA_TYPE = tff.FederatedType(LOCAL_DATA_TYPE, tff.CLIENTS)


@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
    """Averages the per-client `local_eval` losses over all clients."""
    client_model = tff.federated_broadcast(model)
    client_losses = tff.federated_map(local_eval, [client_model, data])
    return tff.federated_mean(client_losses)


# Scalar float held on the server (all-equal), e.g. the learning rate.
SERVER_FLOAT_TYPE = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
def build_fed_avg_process(
    total_clients: int,
    effective_num_clients: int,
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    total_clients: Total number of clients; fixes the `[total_clients, 1]`
      shape of the server-side per-client loss and weight tensors.
    effective_num_clients: Number of clients whose updates are kept each
      round by the loss-based selection step.
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.
    aggregation_process: Optional `MeasuredProcess` used to aggregate client
      model deltas; defaults to a stateless weighted mean.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

    # Normalize scalar learning rates to round-indexed schedules.
    client_lr_schedule = client_lr
    if not callable(client_lr_schedule):
        client_lr_schedule = lambda round_num: client_lr

    server_lr_schedule = server_lr
    if not callable(server_lr_schedule):
        server_lr_schedule = lambda round_num: server_lr

    # Build throwaway model/optimizer instances only to capture their TFF
    # type signatures; the graph scope keeps their variables isolated.
    with tf.Graph().as_default():
        dummy_model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(dummy_model)
        dummy_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(dummy_model, dummy_optimizer)
        optimizer_variable_type = type_conversions.type_from_tensors(
            dummy_optimizer.variables())

    if aggregation_process is None:
        aggregation_process = build_stateless_mean(
            model_delta_type=model_weights_type.trainable)
    if not _is_valid_aggregation_process(aggregation_process):
        raise ProcessTypeError(
            'aggregation_process type signature does not conform to expected '
            'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
            ' Got: {t}'.format(t=aggregation_process.next.type_signature))

    initialize_computation = build_server_init_fn(
        model_fn=model_fn,
        effective_num_clients=effective_num_clients,
        # Initialize with the learning rate for round zero.
        server_optimizer_fn=lambda: server_optimizer_fn(server_lr_schedule(0)),
        aggregation_process=aggregation_process)

    # server_state_type = initialize_computation.type_signature.result
    # model_weights_type = server_state_type.model
    round_num_type = tf.float32

    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
    model_input_type = tff.SequenceType(dummy_model.input_spec)

    # Server-side [total_clients, 1] tensors, indexed by client id.
    client_losses_at_server_type = tff.TensorType(dtype=tf.float32,
                                                  shape=[total_clients, 1])
    clients_weights_at_server_type = tff.TensorType(dtype=tf.float32,
                                                    shape=[total_clients, 1])

    aggregation_state = aggregation_process.initialize.type_signature.result.member

    server_state_type = ServerState(
        model=model_weights_type,
        optimizer_state=optimizer_variable_type,
        round_num=round_num_type,
        effective_num_clients=tf.int32,
        delta_aggregate_state=aggregation_state,
    )

    # @computations.tf_computation(clients_weights_type)
    # def get_zero_weights_all_clients(weights):
    #   return tf.zeros_like(weights, dtype=tf.float32)

    ######################################################
    # def federated_output(local_outputs):
    #   return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)

    # federated_output_computation = computations.federated_computation(
    #       federated_output, federated_local_outputs_type)

    single_id_type = tff.TensorType(dtype=tf.int32, shape=[1, 1])

    @tff.tf_computation(model_input_type, model_weights_type, round_num_type,
                        single_id_type)
    def client_update_fn(tf_dataset, initial_model_weights, round_num,
                         client_id):
        """Runs local training for one client at this round's client LR."""
        client_lr = client_lr_schedule(round_num)
        client_optimizer = client_optimizer_fn(client_lr)
        client_update = create_client_update_fn()
        return client_update(model_fn(), tf_dataset, initial_model_weights,
                             client_optimizer, client_id, client_weight_fn)

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        """Applies the aggregated model delta to the server model."""
        model = model_fn()
        server_lr = server_lr_schedule(server_state.round_num)
        server_optimizer = server_optimizer_fn(server_lr)
        # We initialize the server optimizer variables to avoid creating them
        # within the scope of the tf.function server_update.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    # NOTE(review): identical to `single_id_type` above — likely redundant.
    id_type = tff.TensorType(shape=[1, 1], dtype=tf.int32)

    @tff.tf_computation(clients_weights_at_server_type, id_type)
    def select_weight_fn(clients_weights, local_id):
        """Looks up one client's (possibly zeroed) weight by its id."""
        return select_weight(clients_weights, local_id)

    @tff.tf_computation(client_losses_at_server_type,
                        clients_weights_at_server_type, tf.int32)
    def zero_small_loss_clients(losses_at_server, weights_at_server,
                                effective_num_clients):
        """Receives losses and returns participating clients.

    Args:
      losses_at_server: A `[total_clients, 1]` tensor of per-client losses.
      weights_at_server: A `[total_clients, 1]` tensor of per-client weights.
      effective_num_clients: Number of clients to keep this round.

    Returns:
      The per-client weight tensor restricted to the selected clients
      (presumably the rest zeroed — confirm against
      `redefine_client_weight`).
    """
        return redefine_client_weight(losses_at_server, weights_at_server,
                                      effective_num_clients)

    # @tff.tf_computation(client_losses_type)
    # def dataset_to_tensor_fn(dataset):
    #   return dataset_to_tensor(dataset)
    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               tff.FederatedType(tf_dataset_type, tff.CLIENTS),
                               tff.FederatedType(id_type, tff.CLIENTS))
    def run_one_round(server_state, federated_dataset, ids):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      ids: Per-client `[1, 1]` int32 ids with placement `tff.CLIENTS`,
        used to scatter each client's loss/weight into the server tensors.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num, ids))

        client_weight = client_outputs.client_weight
        client_id = client_outputs.client_id

        #LOSS SELECTION:
        # losses_at_server = tff.federated_collect(client_outputs.model_output)
        # weights_at_server = tff.federated_collect(client_weight)
        @computations.tf_computation
        def zeros_fn():
            """Zero-initialized [total_clients, 1] reduction accumulator."""
            return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

        zero = zeros_fn()

        at_server_type = tff.TensorType(shape=[total_clients, 1],
                                        dtype=tf.float32)
        # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
        client_output_type = client_update_fn.type_signature.result

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_weight(u, t):
            """Scatters one client's weight into the accumulator by its id."""
            value = t.client_weight
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_loss(u, t):
            """Scatters one client's summed loss into the accumulator."""
            value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                               shape=[1, 1])
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        # output_at_server= tff.federated_collect(client_outputs)

        weights_at_server = tff.federated_reduce(client_outputs, zero,
                                                 accumulate_weight)
        losses_at_server = tff.federated_reduce(client_outputs, zero,
                                                accumulate_loss)
        #losses_at_server = tff.federated_aggregate(client_outputs.model_output, zero, accumulate, merge, report)

        selected_clients_weights = tff.federated_map(
            zero_small_loss_clients, (losses_at_server, weights_at_server,
                                      server_state.effective_num_clients))

        # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)

        selected_clients_weights_broadcast = tff.federated_broadcast(
            selected_clients_weights)

        # Each client picks its own (possibly zeroed) weight out of the
        # broadcast server tensor using its id.
        selected_clients_weights_at_client = tff.federated_map(
            select_weight_fn, (selected_clients_weights_broadcast, ids))

        aggregation_output = aggregation_process.next(
            server_state.delta_aggregate_state, client_outputs.weights_delta,
            selected_clients_weights_at_client)

        # model_delta = tff.federated_mean(
        #     client_outputs.weights_delta, weight=client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregation_output.result))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    # @tff.federated_computation
    # def initialize_fn():
    #   return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                          next_fn=run_one_round)
Beispiel #26
0
def build_triehh_process(possible_prefix_extensions: List[str],
                         num_sub_rounds: int,
                         max_num_heavy_hitters: int,
                         max_user_contribution: int,
                         default_terminator: str = '$'):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The select clients sample
  `max_user_contributions` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_heavy_hitters` extensions are used to extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extensions must be a single character strings.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_heavy_hitters: The maximum number of discoverable heavy hitters.
      Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

  @tff.tf_computation
  def server_init_tf():
    # Initial server state: nothing discovered yet, round counter at zero,
    # and a zeroed vote accumulator of shape
    # [max_num_heavy_hitters, num_possible_extensions].
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            possible_prefix_extensions, dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_heavy_hitters,
                   len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              possible_prefix_extensions=tff.TensorType(
                  dtype=tf.string, shape=[len(possible_prefix_extensions)]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32, shape=[None,
                                         len(possible_prefix_extensions)]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_heavy_hitters,
             len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  @tf.function
  def server_update_fn(server_state, sub_round_votes):
    """Applies one server update step to the aggregated client votes."""
    server_state = server_update(
        server_state,
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_heavy_hitters=tf.constant(max_num_heavy_hitters),
        default_terminator=tf.constant(default_terminator, dtype=tf.string))
    return server_state

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type, round_num_type)
  @tf.function
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    """Computes a single client's votes on prefix extensions."""
    result = client_update(tf_dataset, discovered_prefixes,
                           tf.constant(possible_prefix_extensions), round_num,
                           num_sub_rounds, max_num_heavy_hitters,
                           max_user_contribution)
    return result

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    # Server -> clients: the prefixes discovered so far and the round counter.
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes, round_num]))

    # Clients -> server: element-wise sum of per-client vote matrices.
    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    # No per-round metrics are reported; emit an empty struct at the server.
    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  # `federated_eval` is the supported way to run a no-arg tf_computation at
  # the server; invoking `server_init_tf()` directly inside a
  # federated_computation and wrapping it in `federated_value` is deprecated.
  # `tff.templates.IterativeProcess` matches the sibling builders in this file
  # (`tff.utils.IterativeProcess` is a removed alias).
  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
      next_fn=run_one_round)
Beispiel #27
0
def build_triehh_process(
    possible_prefix_extensions: List[str],
    num_sub_rounds: int,
    max_num_prefixes: int,
    threshold: int,
    max_user_contribution: int,
    default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The select clients sample
  `max_user_contributions` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_prefixes` extensions get at least 'threshold' votes are used to
  extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extensions must be a single character strings. This
      list should not contain the default_terminator.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those get at least `threshold` votes are discovered. Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
  if default_terminator in possible_prefix_extensions:
    raise ValueError(
        'default_terminator should not appear in possible_prefix_extensions')

  # Append `default_terminator` to `possible_prefix_extensions` to make sure it
  # is the last item in the list. Rebind to a copy rather than mutating the
  # caller's list in place: appending to the argument would make a second call
  # with the same list object trip the ValueError guard above.
  possible_prefix_extensions = list(possible_prefix_extensions) + [
      default_terminator
  ]

  @tff.tf_computation
  def server_init_tf():
    # Initial server state: nothing discovered, zeroed vote/weight
    # accumulators, round counter at zero.
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitter_frequencies=tf.constant([], dtype=tf.float64),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_prefixes,
                   len(possible_prefix_extensions)]),
        accumulated_weights=tf.constant(0, dtype=tf.int32))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              heavy_hitter_frequencies=tff.TensorType(
                  dtype=tf.float64, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32, shape=[None,
                                         len(possible_prefix_extensions)]),
              accumulated_weights=tff.TensorType(dtype=tf.int32, shape=[]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32, shape=[max_num_prefixes,
                             len(possible_prefix_extensions)])
  sub_round_weight_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(server_state_type, sub_round_votes_type,
                      sub_round_weight_type)
  def server_update_fn(server_state, sub_round_votes, sub_round_weight):
    """Applies one server update step to the aggregated votes and weight."""
    return server_update(
        server_state,
        tf.constant(possible_prefix_extensions),
        sub_round_votes,
        sub_round_weight,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_prefixes=tf.constant(max_num_prefixes),
        threshold=tf.constant(threshold))

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type, round_num_type)
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    """Computes a single client's votes and contribution weight."""
    return client_update(tf_dataset, discovered_prefixes,
                         tf.constant(possible_prefix_extensions), round_num,
                         num_sub_rounds, max_num_prefixes,
                         max_user_contribution)

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    # Server -> clients: the prefixes discovered so far and the round counter.
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, discovered_prefixes, round_num))

    # Clients -> server: sum per-client vote matrices and contribution weights.
    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    accumulated_weights = tff.federated_sum(client_outputs.client_weight)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, accumulated_votes, accumulated_weights))

    # No per-round metrics are reported; emit an empty struct at the server.
    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
      next_fn=run_one_round)