Example no. 1
def _instantiate_aggregation_process(
        aggregation_factory, model_weights_type,
        client_weight_fn) -> tff.templates.AggregationProcess:
    """Constructs aggregation process given factory, checking compatibilty."""
    if aggregation_factory is None:
        aggregation_factory = tff.aggregators.MeanFactory()
        aggregation_process = aggregation_factory.create(
            model_weights_type.trainable, tff.TensorType(tf.float32))
    else:
        # We give precedence to unweighted aggregation.
        if isinstance(aggregation_factory,
                      tff.aggregators.UnweightedAggregationFactory):
            if client_weight_fn is not None:
                logging.warning(
                    'When using an unweighted aggregation, '
                    '`client_weight_fn` should not be specified; found '
                    '`client_weight_fn` %s', client_weight_fn)
            aggregation_process = aggregation_factory.create(
                model_weights_type.trainable)
        elif isinstance(aggregation_factory,
                        tff.aggregators.WeightedAggregationFactory):
            aggregation_process = aggregation_factory.create(
                model_weights_type.trainable, tff.TensorType(tf.float32))
        else:
            raise ValueError('Unknown type of aggregation factory: {}'.format(
                type(aggregation_factory)))
    return aggregation_process
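For context, a minimal usage sketch of the helper above. The `model_weights_type` literal is a made-up single-layer example and does not come from the original snippet; the call exercises the default (weighted mean) branch.

import collections

import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical weights type for a single dense layer (784 -> 10).
model_weights_type = tff.to_type(
    collections.OrderedDict(
        trainable=(tff.TensorType(tf.float32, [784, 10]),
                   tff.TensorType(tf.float32, [10])),
        non_trainable=()))

process = _instantiate_aggregation_process(
    aggregation_factory=None,  # Falls back to a weighted tff.aggregators.MeanFactory.
    model_weights_type=model_weights_type,
    client_weight_fn=None)
print(process.next.type_signature)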
Example no. 2

    def test_build_jax_federated_averaging_process(self):
        batch_type = collections.OrderedDict([
            ('pixels', tff.TensorType(np.float32, (50, 784))),
            ('labels', tff.TensorType(np.int32, (50, )))
        ])

        def random_batch():
            pixels = np.random.uniform(low=0.0, high=1.0,
                                       size=(50, 784)).astype(np.float32)
            labels = np.random.randint(low=0,
                                       high=9,
                                       size=(50, ),
                                       dtype=np.int32)
            return collections.OrderedDict([('pixels', pixels),
                                            ('labels', labels)])

        model_type = collections.OrderedDict([
            ('weights', tff.TensorType(np.float32, (784, 10))),
            ('bias', tff.TensorType(np.float32, (10, )))
        ])

        def loss(model, batch):
            y = jax.nn.softmax(
                jax.numpy.add(
                    jax.numpy.matmul(batch['pixels'], model['weights']),
                    model['bias']))
            targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1),
                                     10)
            return -jax.numpy.mean(
                jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

        trainer = jax_components.build_jax_federated_averaging_process(
            batch_type, model_type, loss, step_size=0.001)

        trainer.next(trainer.initialize(), [[random_batch()]])
Example no. 3

  def test_federated_mean_masked(self):

    value_type = tff.StructType([('a', tff.TensorType(tf.float32, shape=[3])),
                                 ('b', tff.TensorType(tf.float32, shape=[2,
                                                                         1]))])
    weight_type = tff.TensorType(tf.float32)

    federated_mean_masked_fn = fed_pa_schedule.build_federated_mean_masked(
        value_type, weight_type)

    # Check type signature.
    expected_type = tff.FunctionType(
        parameter=collections.OrderedDict(
            value=tff.FederatedType(value_type, tff.CLIENTS),
            weight=tff.FederatedType(weight_type, tff.CLIENTS)),
        result=tff.FederatedType(value_type, tff.SERVER))
    self.assertTrue(
        federated_mean_masked_fn.type_signature.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(
            s=federated_mean_masked_fn.type_signature, t=expected_type))

    # Check correctness of zero masking in the mean.
    values = [
        collections.OrderedDict(
            a=tf.constant([0.0, 1.0, 1.0]), b=tf.constant([[1.0], [2.0]])),
        collections.OrderedDict(
            a=tf.constant([1.0, 3.0, 0.0]), b=tf.constant([[3.0], [0.0]]))
    ]
    weights = [tf.constant(1.0), tf.constant(3.0)]
    output = federated_mean_masked_fn(values, weights)
    expected_output = collections.OrderedDict(
        a=tf.constant([1.0, 2.5, 1.0]), b=tf.constant([[2.5], [2.0]]))
    self.assertAllClose(output, expected_output)
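For reference, the expected values follow from a per-coordinate weighted mean that excludes zero entries: for `a[0]` only the second client (weight 3.0) contributes, giving 3·1.0/3 = 1.0; for `a[1]` both clients contribute, giving (1·1.0 + 3·3.0)/(1 + 3) = 2.5; and for `b[1][0]` only the first client contributes, giving 1·2.0/1 = 2.0.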
Example no. 4
    def test_iterative_process_type_signature(self):
        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type, federated_dataset=dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
Example no. 5
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = MAX_CLIENT_DATASET_SIZE,
    emnist_task: str = 'digit_recognition',
    num_parallel_calls: tf.Tensor = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
    """Creates a preprocessing function for EMNIST client datasets.

  The preprocessing shuffles, repeats, batches, and then reshapes, using
  the `shuffle`, `repeat`, `batch`, and `map` attributes of a
  `tf.data.Dataset`, in that order.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, no shuffling occurs.
    emnist_task: A string indicating the EMNIST task being performed. Must be
      one of 'digit_recognition' or 'autoencoder'. If the former, then elements
      are mapped to tuples of the form (pixels, label), if the latter then
      elements are mapped to tuples of the form (pixels, pixels).
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing discussed above.
  """
    if num_epochs < 1:
        raise ValueError('num_epochs must be a positive integer.')
    if shuffle_buffer_size <= 1:
        shuffle_buffer_size = 1

    if emnist_task == 'digit_recognition':
        mapping_fn = _reshape_for_digit_recognition
    elif emnist_task == 'autoencoder':
        mapping_fn = _reshape_for_autoencoder
    else:
        raise ValueError('emnist_task must be one of "digit_recognition" or '
                         '"autoencoder".')

    # Features are intentionally sorted lexicographically by key for consistency
    # across datasets.
    feature_dtypes = collections.OrderedDict(label=tff.TensorType(tf.int32),
                                             pixels=tff.TensorType(tf.float32,
                                                                   shape=(28,
                                                                          28)))

    @tff.tf_computation(tff.SequenceType(feature_dtypes))
    def preprocess_fn(dataset):
        return dataset.shuffle(shuffle_buffer_size).repeat(num_epochs).batch(
            batch_size,
            drop_remainder=False).map(mapping_fn,
                                      num_parallel_calls=num_parallel_calls)

    return preprocess_fn
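A short usage sketch for the digit-recognition task; the single-element synthetic client dataset below is illustrative only and assumes this code runs in (or imports from) the module defining `create_preprocess_fn`.

import collections

import numpy as np
import tensorflow as tf

preprocess_fn = create_preprocess_fn(num_epochs=1, batch_size=20)

# A synthetic raw client example with the expected (label, pixels) structure.
raw_example = collections.OrderedDict(
    label=np.int32(3), pixels=np.random.rand(28, 28).astype(np.float32))
raw_dataset = tf.data.Dataset.from_tensors(raw_example)

preprocessed_dataset = preprocess_fn(raw_dataset)
print(preprocessed_dataset.element_spec)  # Batched (pixels, label) tuples.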
Example no. 6
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = NUM_EXAMPLES_PER_CLIENT,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE,
    distort_image=False,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE) -> tff.Computation:
  """Creates a preprocessing function for CIFAR-100 client datasets.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, no shuffling occurs.
    crop_shape: A tuple (crop_height, crop_width, num_channels) specifying the
      desired crop shape for pre-processing. This tuple cannot have elements
      exceeding (32, 32, 3), element-wise. The element in the last index should
      be set to 3 to maintain the RGB image structure of the elements.
    distort_image: A boolean indicating whether to perform preprocessing that
      includes image distortion, including random crops and flips.
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  if shuffle_buffer_size <= 1:
    shuffle_buffer_size = 1

  # Features are intentionally sorted lexicographically by key for consistency
  # across datasets.
  feature_dtypes = collections.OrderedDict(
      coarse_label=tff.TensorType(tf.int64),
      image=tff.TensorType(tf.uint8, shape=(32, 32, 3)),
      label=tff.TensorType(tf.int64))

  image_map_fn = build_image_map(crop_shape, distort_image)

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocess_fn(dataset):
    return (
        dataset.shuffle(shuffle_buffer_size).repeat(num_epochs)
        # We map before batching to ensure that the cropping occurs
        # at an image level (eg. we do not perform the same crop on
        # every image within a batch)
        .map(image_map_fn,
             num_parallel_calls=num_parallel_calls).batch(batch_size))

  return preprocess_fn
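As a usage sketch in simulation, the returned computation can be applied directly to a raw client dataset; the CIFAR-100 loader call below is assumed to be available in the environment and is not part of the original snippet.

import tensorflow_federated as tff

cifar_train, _ = tff.simulation.datasets.cifar100.load_data()
preprocess_fn = create_preprocess_fn(num_epochs=1, batch_size=32)

raw_client_ds = cifar_train.create_tf_dataset_for_client(
    cifar_train.client_ids[0])
preprocessed_ds = preprocess_fn(raw_client_ds)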
Example no. 7
    def test_dynamic_lookup_table(self):
        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[None], dtype=tf.string))
        def comp(table_args, to_lookup):
            values = tf.range(tf.shape(table_args)[0])
            initializer = tf.lookup.KeyValueTensorInitializer(
                table_args, values)
            table = tf.lookup.StaticHashTable(initializer, default_value=101)
            return table.lookup(to_lookup)

        result = comp(tf.constant(['a', 'b', 'c']),
                      tf.constant(['a', 'z', 'c']))
        self.assertAllEqual(result, [0, 101, 2])
Example no. 8
    def test_lookup_table(self):
        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[], dtype=tf.string))
        def foo(table_args, to_lookup):
            values = tf.range(tf.shape(table_args)[0])
            initializer = tf.lookup.KeyValueTensorInitializer(
                table_args, values)
            table = tf.lookup.StaticHashTable(initializer, 100)
            return table.lookup(to_lookup)

        self.assertEqual(foo(tf.constant(['a', 'b']), 'a'), 0)
        self.assertEqual(foo(tf.constant(['d', 'e', 'f']), 'f'), 2)
        self.assertEqual(foo(tf.constant(['d', 'e', 'f', 'g', 'h', 'i']), 'i'),
                         5)
Example no. 9
    def test_reinitialize_dynamic_lookup_table(self):
        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[], dtype=tf.string))
        def comp(table_args, to_lookup):
            values = tf.range(tf.shape(table_args)[0])
            initializer = tf.lookup.KeyValueTensorInitializer(
                table_args, values)
            table = tf.lookup.StaticHashTable(initializer, default_value=101)
            return table.lookup(to_lookup)

        expected_zero = comp(tf.constant(['a', 'b', 'c']), tf.constant('a'))
        expected_three = comp(tf.constant(['a', 'b', 'c', 'd']),
                              tf.constant('d'))

        self.assertEqual(expected_zero, 0)
        self.assertEqual(expected_three, 3)
Example no. 10
    def test_bad_type_coercion_raises(self):
        tensor_type = tff.TensorType(shape=[None], dtype=tf.float32)

        @tff.tf_computation(tensor_type)
        def foo(x):
            # We will pass in a tensor which passes the TFF type check, but fails the
            # reshape.
            return tf.reshape(x, [])

        @tff.federated_computation(tff.type_at_clients(tensor_type))
        def map_foo_at_clients(x):
            return tff.federated_map(foo, x)

        @tff.federated_computation(tff.type_at_server(tensor_type))
        def map_foo_at_server(x):
            return tff.federated_map(foo, x)

        bad_tensor = tf.constant([1.] * 10, dtype=tf.float32)
        good_tensor = tf.constant([1.], dtype=tf.float32)
        # Ensure running this computation at both placements, or unplaced, still
        # raises.
        with self.assertRaises(Exception):
            foo(bad_tensor)
        with self.assertRaises(Exception):
            map_foo_at_server(bad_tensor)
        with self.assertRaises(Exception):
            map_foo_at_clients([bad_tensor] * 10)
        # We give the distributed runtime a chance to clean itself up, otherwise
        # workers may be getting SIGABRT while they are handling another exception,
        # causing the test infra to crash. Making a successful call ensures that
        # cleanup happens after failures have been handled.
        map_foo_at_clients([good_tensor] * 10)
Example no. 11
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int,
    mapping_fn: Callable[[Any], Any],
    num_parallel_calls: tf.Tensor = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
    """Creates a preprocessing function for EMNIST client datasets.

  The preprocessing shuffles, repeats, batches, and then reshapes, using
  the `shuffle`, `repeat`, `batch`, and `map` attributes of a
  `tf.data.Dataset`, in that order.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, no shuffling occurs.
    mapping_fn: A mapping function applied to EMNIST elements after shuffling,
      repeating, and batching occurs.
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing discussed above.
  """
    if num_epochs < 1:
        raise ValueError('num_epochs must be a positive integer.')
    if shuffle_buffer_size <= 1:
        shuffle_buffer_size = 1

    feature_dtypes = collections.OrderedDict(pixels=tff.TensorType(tf.float32,
                                                                   shape=(28,
                                                                          28)),
                                             label=tff.TensorType(tf.int32))

    @tff.tf_computation(tff.SequenceType(feature_dtypes))
    def preprocess_fn(dataset):
        return dataset.shuffle(shuffle_buffer_size).repeat(num_epochs).batch(
            batch_size,
            drop_remainder=False).map(mapping_fn,
                                      num_parallel_calls=num_parallel_calls)

    return preprocess_fn
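A sketch of a function that could be passed as `mapping_fn`; the flattening below is illustrative and is not one of the module's own `_reshape_for_*` helpers.

import tensorflow as tf

def flatten_batch(element):
    # `element` is a batched OrderedDict with 'pixels' of shape [B, 28, 28]
    # and 'label' of shape [B]; flatten each pixel grid into a 784-vector.
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

preprocess_fn = create_preprocess_fn(
    num_epochs=1, batch_size=10, shuffle_buffer_size=100,
    mapping_fn=flatten_batch)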
Example no. 12
def make_integer_secure_sum(input_shape):
    if input_shape is None:
        member_type = tf.int32
    else:
        member_type = tff.TensorType(tf.int32, input_shape)

    @tff.federated_computation(tff.FederatedType(member_type, tff.CLIENTS))
    def secure_paillier_addition(x):
        return tff.federated_secure_sum(x, 64)

    return secure_paillier_addition
Example no. 13
def get_preprocessing_fn(image_size: int, batch_size: int, num_epochs: int,
                         max_elements: int,
                         shuffle_buffer_size: int) -> tff.Computation:
  """Creates a preprocessing function for federated training.

  Args:
    image_size: The height and width of images after preprocessing.
    batch_size: Batch size used for training.
    num_epochs: Number of training epochs.
    max_elements: The maximum number of elements taken from the dataset. It has
      to be a positive value or -1 (which means all elements are taken).
    shuffle_buffer_size: Buffer size used in shuffling.

  Returns:
    A `tff.Computation` that transforms the raw `tf.data.Dataset` of a client
    into a `tf.data.Dataset` that is ready for training.
  """
  _check_positive('image_size', image_size)
  _check_positive('batch_size', batch_size)
  _check_positive('num_epochs', num_epochs)
  _check_positive('shuffle_buffer_size', shuffle_buffer_size)
  if max_elements <= 0 and max_elements != -1:
    raise ValueError('Expected a positive value or -1 for `max_elements`, '
                     f'found {max_elements}.')

  feature_dtypes = collections.OrderedDict()
  feature_dtypes['image/decoded'] = tff.TensorType(
      dtype=tf.uint8, shape=[None, None, None])
  feature_dtypes['class'] = tff.TensorType(dtype=tf.int64, shape=[1])

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocessing_fn(dataset: tf.data.Dataset) -> tf.data.Dataset:
    dataset = dataset.map(
        functools.partial(_map_fn, is_training=True, image_size=image_size),
        num_parallel_calls=tf.data.experimental.AUTOTUNE).shuffle(
            shuffle_buffer_size).take(max_elements).repeat(num_epochs).batch(
                batch_size)
    return dataset

  return preprocessing_fn
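A minimal call sketch; the argument values are illustrative, and `_map_fn` and `_check_positive` are module helpers assumed to be defined alongside this function.

preprocess_fn = get_preprocessing_fn(
    image_size=224,
    batch_size=32,
    num_epochs=1,
    max_elements=-1,  # -1 means every element of the client dataset is taken.
    shuffle_buffer_size=128)
# In simulation the resulting `tff.Computation` is applied to a raw client
# `tf.data.Dataset`, e.g. train_ds = preprocess_fn(raw_client_dataset).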
Example no. 14
    def test_eval_fn_has_correct_type_signature(self):
        metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]
        eval_fn = evaluation.build_centralized_evaluation(
            tff_model_fn, metrics_builder)
        actual_type = eval_fn.type_signature

        model_type = tff.FederatedType(
            tff.learning.ModelWeights(
                trainable=(
                    tff.TensorType(tf.float32, [1, 1]),
                    tff.TensorType(tf.float32, [1]),
                ),
                non_trainable=(),
            ), tff.SERVER)

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.SERVER)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                num_examples=tff.TensorType(tf.float32)), tff.SERVER)

        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            model_weights=model_type, centralized_dataset=dataset_type),
                                         result=metrics_type)

        actual_type.check_assignable_from(expected_type)
Example no. 15
    def test_secure_sum_larger_matrices(self, first_dim, second_dim):
        NUM_CLIENTS = 5
        shape = (first_dim, second_dim)
        input_tensor = np.ones(shape, dtype=np.int32)
        member_type = tff.TensorType(tf.int32, shape)

        @tff.federated_computation(tff.FederatedType(member_type, tff.CLIENTS))
        def secure_paillier_addition(x):
            return tff.federated_secure_sum(x, 64)

        with _install_executor(factory.local_paillier_executor_factory()):
            result = secure_paillier_addition([input_tensor] * NUM_CLIENTS)
        expected = input_tensor * NUM_CLIENTS
        np.testing.assert_almost_equal(result, expected)
Example no. 16

    def test_iterative_process_type_signature(self):
        iterative_process = decay_iterative_process_builder.from_flags(
            input_spec=get_input_spec(),
            model_builder=model_builder,
            loss_builder=loss_builder,
            metrics_builder=metrics_builder)

        dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=FLAGS.client_learning_rate,
            decay_factor=FLAGS.client_decay_factor,
            min_delta=FLAGS.min_delta,
            min_lr=FLAGS.min_lr,
            window_size=FLAGS.window_size,
            patience=FLAGS.patience)
        lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                loss=tff.TensorType(tf.float32)), tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=(server_state_type,
                                                    dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertTrue(actual_type.is_equivalent_to(expected_type))
Example no. 17
        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:

            factory = importance_aggregation_factory.ImportanceSamplingFactory(
                FLAGS.clients_per_round)
            weights_type = importance_aggregation_factory.weights_type_from_model_fn(
                model_fn)
            importance_aggregation_process = factory.create(
                value_type=weights_type,
                weight_type=tff.TensorType(tf.float32))

            return importance_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_optimizer_fn,
                client_lr=client_lr_schedule,
                server_optimizer_fn=server_optimizer_fn,
                server_lr=server_lr_schedule,
                aggregation_process=importance_aggregation_process)
Example no. 18
 async def _compute_reshape_on_tensor(self, tensor, output_shape):
     tensor_type = tensor.type_signature.member
     shape_type = type_conversions.infer_type(output_shape)
     reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
         paillier_comp.make_reshape_tensor,
         self._reshape_function_cache,
         arg_spec=(tensor_type, ),
         output_shape=output_shape)
     tensor_placement = tensor.type_signature.placement
     children = self._get_child_executors(tensor_placement)
     py_typecheck.check_len(tensor.internal_representation, len(children))
     reshaper_fns = await asyncio.gather(*[
         ex.create_value(reshaper_proto, reshaper_type) for ex in children
     ])
     reshaped_tensors = await asyncio.gather(*[
         ex.create_call(fn, arg) for ex, fn, arg in zip(
             children, reshaper_fns, tensor.internal_representation)
     ])
     output_tensor_spec = tff.FederatedType(
         tff.TensorType(tensor_type.dtype, output_shape), tensor_placement,
         tensor.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         reshaped_tensors, output_tensor_spec)
Example no. 19
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from federated_aggregations import paillier

NUM_CLIENTS = 5

paillier_factory = paillier.local_paillier_executor_factory(NUM_CLIENTS)
paillier_context = tff.framework.ExecutionContext(paillier_factory)
tff.framework.set_default_context(paillier_context)


@tff.federated_computation(
    tff.FederatedType(tff.TensorType(tf.int32, [2]), tff.CLIENTS),
    tff.TensorType(tf.int32))
def secure_paillier_addition(x, bitwidth):
    return tff.federated_secure_sum(x, bitwidth)


base = np.array([1, 2], np.int32)
x = [base + i for i in range(NUM_CLIENTS)]
result = secure_paillier_addition(x, 32)
print(result)
Example no. 20
                          dtype=np.float32),
            'y': np.array([source[1][i] for i in batch_samples], dtype=np.int32)})

    # add noise 0x-0.2x
    ratio = num * 0.05
    sum_agent = int(len(all_samples))
    index = 0
    for i in range(0, sum_agent):
        noiseHere = ratio * np.random.randn(28*28)
        output_sequence[int(i/BATCH_SIZE)]['x'][i % BATCH_SIZE] = checkRange(np.add(
            output_sequence[int(i/BATCH_SIZE)]['x'][i % BATCH_SIZE], noiseHere))
    return output_sequence


BATCH_TYPE = tff.NamedTupleType([
    ('x', tff.TensorType(tf.float32, [None, 784])),
    ('y', tff.TensorType(tf.int32, [None]))])

MODEL_TYPE = tff.NamedTupleType([
    ('weights', tff.TensorType(tf.float32, [784, 10])),
    ('bias', tff.TensorType(tf.float32, [10]))])


@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
    predicted_y = tf.nn.softmax(tf.matmul(batch.x, model.weights) + model.bias)
    return -tf.reduce_mean(tf.reduce_sum(
        tf.one_hot(batch.y, 10) * tf.log(predicted_y), axis=[1]))


@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
Example no. 21

    def run_one_round(server_state, federated_dataset, ids):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement `tff.CLIENTS`.
      ids: Client identifiers with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num, ids))

        client_weight = client_outputs.client_weight
        client_id = client_outputs.client_id

        #LOSS SELECTION:
        # losses_at_server = tff.federated_collect(client_outputs.model_output)
        # weights_at_server = tff.federated_collect(client_weight)
        @computations.tf_computation
        def zeros_fn():
            return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

        zero = zeros_fn()

        at_server_type = tff.TensorType(shape=[total_clients, 1],
                                        dtype=tf.float32)
        # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
        client_output_type = client_update_fn.type_signature.result

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_weight(u, t):
            value = t.client_weight
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_loss(u, t):
            value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                               shape=[1, 1])
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        # output_at_server= tff.federated_collect(client_outputs)

        weights_at_server = tff.federated_reduce(client_outputs, zero,
                                                 accumulate_weight)
        losses_at_server = tff.federated_reduce(client_outputs, zero,
                                                accumulate_loss)
        #losses_at_server = tff.federated_aggregate(client_outputs.model_output, zero, accumulate, merge, report)

        selected_clients_weights = tff.federated_map(
            zero_small_loss_clients, (losses_at_server, weights_at_server,
                                      server_state.effective_num_clients))

        # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)

        selected_clients_weights_broadcast = tff.federated_broadcast(
            selected_clients_weights)

        selected_clients_weights_at_client = tff.federated_map(
            select_weight_fn, (selected_clients_weights_broadcast, ids))

        aggregation_output = aggregation_process.next(
            server_state.delta_aggregate_state, client_outputs.weights_delta,
            selected_clients_weights_at_client)

        # model_delta = tff.federated_mean(
        #     client_outputs.weights_delta, weight=client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregation_output.result))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example no. 22
def build_triehh_process(
        possible_prefix_extensions: List[str],
        num_sub_rounds: int,
        max_num_prefixes: int,
        threshold: int,
        max_user_contribution: int,
        default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
    """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has discovered so far
  and the list of `possible_prefix_extensions` to a small fraction of selected
  clients. The selected clients sample `max_user_contribution` words from their
  local datasets and use them to vote on character extensions to the
  broadcasted popular prefixes. Client votes are accumulated across
  `num_sub_rounds` rounds, and then the top `max_num_prefixes` extensions that
  receive at least `threshold` votes are used to extend the already discovered
  prefixes; the extended prefixes are used in the next round. When an already
  discovered prefix is extended by `default_terminator`, it is added to the
  list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extension must be a single-character string. This
      list should not contain `default_terminator`.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those that get at least `threshold` votes are discovered. Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
    if default_terminator in possible_prefix_extensions:
        raise ValueError(
            'default_terminator should not appear in possible_prefix_extensions'
        )

    # Append `default_terminator` to `possible_prefix_extensions` to make sure it
    # is the last item in the list.
    possible_prefix_extensions.append(default_terminator)

    @tff.tf_computation
    def server_init_tf():
        return ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitters_frequencies=tf.constant([], dtype=tf.float64),
            discovered_prefixes=tf.constant([''], dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                dtype=tf.int32,
                shape=[max_num_prefixes,
                       len(possible_prefix_extensions)]),
            accumulated_weights=tf.constant(0, dtype=tf.int32))

    # We cannot use server_init_tf.type_signature.result because the
    # discovered_* fields need to have [None] shapes, since they will grow over
    # time.
    server_state_type = (tff.to_type(
        ServerState(
            discovered_heavy_hitters=tff.TensorType(dtype=tf.string,
                                                    shape=[None]),
            heavy_hitters_frequencies=tff.TensorType(dtype=tf.float64,
                                                     shape=[None]),
            discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
            round_num=tff.TensorType(dtype=tf.int32, shape=[]),
            accumulated_votes=tff.TensorType(
                dtype=tf.int32, shape=[None,
                                       len(possible_prefix_extensions)]),
            accumulated_weights=tff.TensorType(dtype=tf.int32, shape=[]),
        )))

    sub_round_votes_type = tff.TensorType(
        dtype=tf.int32,
        shape=[max_num_prefixes,
               len(possible_prefix_extensions)])
    sub_round_weight_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(server_state_type, sub_round_votes_type,
                        sub_round_weight_type)
    def server_update_fn(server_state, sub_round_votes, sub_round_weight):
        return server_update(server_state,
                             tf.constant(possible_prefix_extensions),
                             sub_round_votes,
                             sub_round_weight,
                             num_sub_rounds=tf.constant(num_sub_rounds),
                             max_num_prefixes=tf.constant(max_num_prefixes),
                             threshold=tf.constant(threshold))

    tf_dataset_type = tff.SequenceType(tf.string)
    discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
    round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                        round_num_type)
    def client_update_fn(tf_dataset, discovered_prefixes, round_num):
        return client_update(tf_dataset, discovered_prefixes,
                             tf.constant(possible_prefix_extensions),
                             round_num, num_sub_rounds, max_num_prefixes,
                             max_user_contribution)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                               tff.CLIENTS,
                                               all_equal=False)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
        discovered_prefixes = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, discovered_prefixes, round_num))

        accumulated_votes = tff.federated_sum(client_outputs.client_votes)

        accumulated_weights = tff.federated_sum(client_outputs.client_weight)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, accumulated_votes, accumulated_weights))

        server_output = tff.federated_value([], tff.SERVER)

        return server_state, server_output

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
        next_fn=run_one_round)
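A minimal construction sketch for the process built above; the parameter values are illustrative only.

triehh_process = build_triehh_process(
    possible_prefix_extensions=list('abcdefghijklmnopqrstuvwxyz'),
    num_sub_rounds=1,
    max_num_prefixes=10,
    threshold=1,
    max_user_contribution=10)

state = triehh_process.initialize()
# Each client dataset is a `tf.data.Dataset` of tf.string words, for example:
# client_data = [tf.data.Dataset.from_tensor_slices(['hello', 'heavy'])]
# state, _ = triehh_process.next(state, client_data)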
Example no. 23
    def test_build_with_preprocess_funtion(self):
        test_dataset = tf.data.Dataset.range(5)
        client_datasets_type = tff.FederatedType(
            tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_batch(x):
                return collections.OrderedDict(x=[float(x) * 1.0],
                                               y=[float(x) * 3.0 + 1.0])

            return ds.map(to_batch).repeat().batch(2).take(3)

        client_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)
        server_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            dataset_preprocess_comp=preprocess_dataset)

        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)
        expected_result_type = (server_state_type, output_type)

        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
Example no. 24
    def __init__(
        self,
        model_fn,
        m,
        n,
        j_max,
        importance_sampling,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    ):
        """Builds the TFF computations for optimization using federated averaging.
        Args:
        model_fn: A no-arg function that returns a
          `simple_fedavg_tf.KerasModelWrapper`.
        server_optimizer_fn: A no-arg function that returns a
          `tf.keras.optimizers.Optimizer` for server update.
        client_optimizer_fn: A no-arg function that returns a
          `tf.keras.optimizers.Optimizer` for client update.
        Returns:
        A `tff.templates.IterativeProcess`.
        """

        dummy_model = model_fn()

        @tff.tf_computation
        def server_init_tf():
            model = model_fn()
            server_optimizer = server_optimizer_fn()
            _initialize_optimizer_vars(model, server_optimizer)
            return ServerState(model_weights=model.weights,
                               optimizer_state=server_optimizer.variables(),
                               round_num=0)

        server_state_type = server_init_tf.type_signature.result

        model_weights_type = server_state_type.model_weights

        @tff.tf_computation(server_state_type, model_weights_type.trainable)
        def server_update_fn(server_state, model_delta):
            model = model_fn()
            server_optimizer = server_optimizer_fn()
            _initialize_optimizer_vars(model, server_optimizer)
            return server_update(model, server_optimizer, server_state,
                                 model_delta)

        @tff.tf_computation(server_state_type)
        def server_message_fn(server_state):
            return build_server_broadcast_message(server_state)

        server_message_type = server_message_fn.type_signature.result
        tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

        @tff.tf_computation(tf_dataset_type, server_message_type)
        def client_update_fn(tf_dataset, server_message):
            model = model_fn()
            client_optimizer = client_optimizer_fn()
            return client_update(model, tf_dataset, server_message,
                                 client_optimizer)

        federated_server_state_type = tff.FederatedType(
            server_state_type, tff.SERVER)
        federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                                   tff.CLIENTS)

        @tff.tf_computation(
            tf.float32,
            tf.float32,
        )
        def scale(update_norm, sum_update_norms):
            if importance_sampling:
                return tf.minimum(
                    1., tf.divide(tf.multiply(update_norm, m),
                                  sum_update_norms))
            else:
                return tf.divide(m, n)

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                                   tff.FederatedType(tf.float32, tff.CLIENTS,
                                                     True))
        def scale_on_clients(update_norm, sum_update_norms):
            return tff.federated_map(scale, (update_norm, sum_update_norms))

        @tff.tf_computation(tf.float32)
        def create_prob_message(prob):
            def f1():
                return tf.stack([prob, 1.])

            def f2():
                return tf.constant([0., 0.])

            prob_message = tf.cond(tf.less(prob, 1), f1, f2)
            return prob_message

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
        def create_prob_message_on_clients(prob):
            return tff.federated_map(create_prob_message, prob)

        @tff.tf_computation(tff.TensorType(tf.float32, (2, )))
        def compute_rescaling(prob_aggreg):
            rescaling_factor = (m - n + prob_aggreg[1]) / prob_aggreg[0]
            return rescaling_factor

        @tff.federated_computation(
            tff.FederatedType(tff.TensorType(tf.float32, (2, )), tff.SERVER))
        def compute_rescaling_on_master(prob_aggreg):
            return tff.federated_map(compute_rescaling, prob_aggreg)

        @tff.tf_computation(tf.float32, tf.float32)
        def rescale_prob(prob, rescaling_factor):
            return tf.minimum(1., tf.multiply(prob, rescaling_factor))

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                                   tff.FederatedType(tf.float32, tff.CLIENTS,
                                                     True))
        def rescale_prob_on_clients(prob, rescaling_factor):
            return tff.federated_map(rescale_prob, (prob, rescaling_factor))

        @tff.tf_computation(tf.float32)
        def compute_weights_is_fn(prob):
            def f1():
                return 1. / prob

            def f2():
                return 0.

            weight = tf.cond(tf.less(tf.random.uniform(()), prob), f1, f2)
            return weight

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
        def compute_weights_is(prob):
            return tff.federated_map(compute_weights_is_fn, prob)

        @tff.federated_computation(
            tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
            tff.FederatedType(tf.float32, tff.CLIENTS))
        def compute_round_model_delta(weights_delta, weights_denom):
            return tff.federated_mean(weights_delta, weight=weights_denom)

        @tff.federated_computation(federated_server_state_type,
                                   tff.FederatedType(
                                       model_weights_type.trainable,
                                       tff.SERVER))
        def update_server_state(server_state, round_model_delta):
            return tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                                   tff.FederatedType(tf.float32, tff.CLIENTS))
        def compute_loss_metric(model_output, weight_denom):
            return tff.federated_mean(model_output, weight=weight_denom)

        @tff.tf_computation(model_weights_type.trainable, tf.float32)
        def rescale_and_remove_fn(weights_delta, weights_is):
            return [
                tf.math.scalar_mul(weights_is, weights_layer_delta)
                for weights_layer_delta in weights_delta
            ]

        @tff.federated_computation(
            tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
            tff.FederatedType(tf.float32, tff.CLIENTS))
        def rescale_and_remove(weights_delta, weights_is):
            return tff.federated_map(rescale_and_remove_fn,
                                     (weights_delta, weights_is))

        @tff.federated_computation(federated_server_state_type,
                                   federated_dataset_type)
        def run_gradient_computation_round(server_state, federated_dataset):
            """Orchestration logic for one round of gradient computation.
            Args:
              server_state: A `ServerState`.
              federated_dataset: A federated `tf.data.Dataset` with placement
                `tff.CLIENTS`.
            Returns:
              A tuple of the clients' initial inclusion probabilities and the per-client `ClientOutput`.
            """
            server_message = tff.federated_map(server_message_fn, server_state)
            server_message_at_client = tff.federated_broadcast(server_message)

            client_outputs = tff.federated_map(
                client_update_fn,
                (federated_dataset, server_message_at_client))

            update_norm_sum_weighted = tff.federated_sum(
                client_outputs.update_norm_weighted)
            norm_sum_clients_weighted = tff.federated_broadcast(
                update_norm_sum_weighted)

            prob_init = scale_on_clients(client_outputs.update_norm_weighted,
                                         norm_sum_clients_weighted)
            return prob_init, client_outputs

        @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
        def run_one_inner_loop_weights_computation(prob):
            """Orchestration logic for one round of computation.
            Args:
              prob: Probability of each client to communicate update.
            Returns:
            A tuple of the updated client probabilities and the rescaling factor (a `tf.float32`).
            """

            prob_message = create_prob_message_on_clients(prob)
            prob_aggreg = tff.federated_sum(prob_message)
            rescaling_factor_master = compute_rescaling_on_master(prob_aggreg)
            rescaling_factor_clients = tff.federated_broadcast(
                rescaling_factor_master)
            prob = rescale_prob_on_clients(prob, rescaling_factor_clients)

            return prob, rescaling_factor_master

        @tff.federated_computation
        def server_init_tff():
            """Orchestration logic for server model initialization."""
            return tff.federated_value(server_init_tf(), tff.SERVER)

        def run_one_round(server_state, federated_dataset):
            """Orchestration logic for one round of computation.
            Args:
              server_state: A `ServerState`.
              federated_dataset: A federated `tf.data.Dataset` with placement
                `tff.CLIENTS`.
            Returns:
            A tuple of the updated `ServerState`, the average loss metric, and the clients' inclusion probabilities.
            """
            prob, client_outputs = run_gradient_computation_round(
                server_state, federated_dataset)

            if importance_sampling:
                for j in range(j_max):
                    prob, rescaling_factor = run_one_inner_loop_weights_computation(
                        prob)
                    if rescaling_factor <= 1:
                        break

            weight_denom = [
                client_output.client_weight for client_output in client_outputs
            ]
            weights_delta = [
                client_output.weights_delta for client_output in client_outputs
            ]

            # rescale weights based on sampling procedure
            weights_is = compute_weights_is(prob)
            weights_delta = rescale_and_remove(weights_delta, weights_is)

            round_model_delta = compute_round_model_delta(
                weights_delta, weight_denom)

            server_state = update_server_state(server_state, round_model_delta)

            model_output = [
                client_output.model_output for client_output in client_outputs
            ]
            round_loss_metric = compute_loss_metric(model_output, weight_denom)

            prob_numpy = []
            for p in prob:
                prob_numpy.append(p.numpy())

            return server_state, round_loss_metric, prob_numpy

        self.next = run_one_round
        self.initialize = server_init_tff
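The inner-loop computations above cap each client's inclusion probability at 1 while keeping the expected number of participating clients fixed. A standalone NumPy sketch of that rescaling logic, under the assumption that `m` is the expected number of sampled clients and `n` the total number of clients:

import numpy as np

def rescale_probabilities(update_norms, m, j_max=10):
    # Initial probabilities proportional to update norms, capped at 1
    # (mirrors `scale` above with importance sampling enabled).
    n = len(update_norms)
    prob = np.minimum(1.0, m * update_norms / update_norms.sum())
    for _ in range(j_max):
        unsaturated = prob < 1.0
        if not unsaturated.any():
            break
        # Mirrors `compute_rescaling`: redistribute mass so the probabilities of
        # unsaturated clients sum to m minus the number of saturated clients.
        rescaling = (m - n + unsaturated.sum()) / prob[unsaturated].sum()
        if rescaling <= 1.0:
            break
        prob = np.minimum(1.0, prob * rescaling)
    return prob

print(rescale_probabilities(np.array([0.1, 0.5, 2.0, 4.0]), m=2))
# The resulting probabilities sum to (approximately) m.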
Example no. 25
def create_train_dataset_preprocess_fn(vocab: List[str],
                                       client_batch_size: int,
                                       client_epochs_per_round: int,
                                       max_seq_len: int,
                                       max_training_elements_per_user: int,
                                       max_batches_per_user=-1,
                                       max_shuffle_buffer_size=10000):
    """Creates preprocessing functions for stackoverflow data.

  This function returns a function which takes a dataset and returns a dataset,
  generally for mapping over a set of unprocessed client datasets during
  training.

  Args:
    vocab: Vocabulary which defines the embedding.
    client_batch_size: Integer representing batch size to use on the clients.
    client_epochs_per_round: Number of epochs for which to repeat train client
      dataset.
    max_seq_len: Integer determining shape of padded batches. Sequences will be
      padded up to this length, and sentences longer than `max_seq_len` will be
      truncated to this length.
    max_training_elements_per_user: Integer controlling the maximum number of
      elements to take per user. If -1, takes all elements for each user.
    max_batches_per_user: If set to a positive integer, the maximum number of
      batches in each client's dataset.
    max_shuffle_buffer_size: Maximum shuffle buffer size.

  Returns:
    A `preprocess_train` function that maps a raw client dataset to a dataset
    of preprocessed batches, as described above.
  """
    if client_batch_size <= 0:
        raise ValueError(
            'client_batch_size must be a positive integer; you have '
            'passed {}'.format(client_batch_size))
    elif client_epochs_per_round == -1 and max_batches_per_user == -1:
        raise ValueError(
            'Argument client_epochs_per_round is set to -1. If this is'
            ' intended, then max_batches_per_user must be set to '
            'some positive integer.')
    elif max_seq_len <= 0:
        raise ValueError('max_seq_len must be a positive integer; you have '
                         'passed {}'.format(max_seq_len))
    elif max_training_elements_per_user < -1:
        raise ValueError(
            'max_training_elements_per_user must be an integer at '
            'least -1; you have passed {}'.format(
                max_training_elements_per_user))

    if (max_training_elements_per_user == -1
            or max_training_elements_per_user > max_shuffle_buffer_size):
        shuffle_buffer_size = max_shuffle_buffer_size
    else:
        shuffle_buffer_size = max_training_elements_per_user

    feature_dtypes = [
        ('creation_date', tf.string),
        ('title', tf.string),
        ('score', tf.int64),
        ('tags', tf.string),
        ('tokens', tf.string),
        ('type', tf.string),
    ]

    @tff.tf_computation(
        tff.SequenceType(
            tff.NamedTupleType([(name, tff.TensorType(dtype=dtype, shape=()))
                                for name, dtype in feature_dtypes])))
    def preprocess_train(dataset):
        to_ids = build_to_ids_fn(vocab, max_seq_len)
        dataset = dataset.take(max_training_elements_per_user)
        if shuffle_buffer_size > 0:
            logging.info('Adding shuffle with buffer size: %d',
                         shuffle_buffer_size)
            dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.repeat(client_epochs_per_round)
        dataset = dataset.map(to_ids,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = batch_and_split(dataset, max_seq_len, client_batch_size)
        return dataset.take(max_batches_per_user)

    return preprocess_train
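A minimal call sketch; the vocabulary is illustrative, and `build_to_ids_fn` and `batch_and_split` are module helpers assumed to be defined alongside this function.

preprocess_train = create_train_dataset_preprocess_fn(
    vocab=['the', 'to', 'and', 'of', 'a'],
    client_batch_size=16,
    client_epochs_per_round=1,
    max_seq_len=20,
    max_training_elements_per_user=128)
# `preprocess_train` is a `tff.Computation` that maps a raw Stack Overflow
# client dataset to padded, batched token-id examples ready for training.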
Example no. 26
    # all_samples = [i for i in range(int(num*(len(source[1])/NUM_AGENT)), int((num+1)*(len(source[1])/NUM_AGENT)))]

    for i in range(0, len(all_samples), BATCH_SIZE):
        batch_samples = all_samples[i:i + BATCH_SIZE]
        output_sequence.append({
            'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
            'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
        })
    return output_sequence


BATCH_TYPE = tff.NamedTupleType([('x', tff.TensorType(tf.float32,
                                                      [None, 784])),
                                 ('y', tff.TensorType(tf.int32, [None]))])

MODEL_TYPE = tff.NamedTupleType([('weights',
                                  tff.TensorType(tf.float32, [784, 10])),
                                 ('bias', tff.TensorType(tf.float32, [10]))])


@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
    predicted_y = tf.nn.softmax(tf.matmul(batch.x, model.weights) + model.bias)
    return -tf.reduce_mean(
        tf.reduce_sum(tf.one_hot(batch.y, 10) * tf.math.log(predicted_y), axis=[1]))


@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
Example n. 27
0
def build_triehh_process(possible_prefix_extensions: List[str],
                         num_sub_rounds: int,
                         max_num_heavy_hitters: int,
                         max_user_contribution: int,
                         default_terminator: str = '$'):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The selected clients sample
  `max_user_contribution` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_heavy_hitters` extensions are used to extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extension must be a single-character string.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_heavy_hitters: The maximum number of discoverable heavy hitters.
      Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

  @tff.tf_computation
  def server_init_tf():
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            possible_prefix_extensions, dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_heavy_hitters,
                   len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              possible_prefix_extensions=tff.TensorType(
                  dtype=tf.string, shape=[len(possible_prefix_extensions)]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32, shape=[None,
                                         len(possible_prefix_extensions)]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_heavy_hitters,
             len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  @tf.function
  def server_update_fn(server_state, sub_round_votes):
    server_state = server_update(
        server_state,
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_heavy_hitters=tf.constant(max_num_heavy_hitters),
        default_terminator=tf.constant(default_terminator, dtype=tf.string))
    return server_state

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type, round_num_type)
  @tf.function
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    result = client_update(tf_dataset, discovered_prefixes,
                           tf.constant(possible_prefix_extensions), round_num,
                           num_sub_rounds, max_num_heavy_hitters,
                           max_user_contribution)
    return result

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of an updated `ServerState` and an empty server output.
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes, round_num]))

    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.utils.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
      next_fn=run_one_round)
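A hedged usage sketch of the returned process, not part of the original example: the extension alphabet, the two toy client datasets, and the number of rounds are illustrative only, and the sketch assumes the module's `ServerState`, `server_update`, and `client_update` helpers are in scope.

import tensorflow as tf

triehh_process = build_triehh_process(
    possible_prefix_extensions=list('abcdefghijklmnopqrstuvwxyz') + ['$'],
    num_sub_rounds=1,
    max_num_heavy_hitters=10,
    max_user_contribution=10)

# Toy client datasets of raw strings; a real run would sample many clients.
sampled_client_datasets = [
    tf.data.Dataset.from_tensor_slices(['hello', 'heavy', 'hitters']),
    tf.data.Dataset.from_tensor_slices(['hello', 'world']),
]

state = triehh_process.initialize()
for _ in range(5):
  # `run_one_round` returns (server_state, empty server output).
  state, _ = triehh_process.next(state, sampled_client_datasets)
print(state.discovered_heavy_hitters)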
Example n. 28
0
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    model_update_aggregation_factory=None,
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
    """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used to apply updates to the global model.
    model_update_aggregation_factory: An optional
      `tff.aggregators.AggregationFactory` that constructs a
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation.
    client_update_tf: A `tf.function` that computes the `ClientOutput`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
        weights_type = tff.learning.framework.weights_type_from_model(
            dummy_model_for_metadata)

    if model_update_aggregation_factory is None:
        model_update_aggregation_factory = tff.aggregators.MeanFactory()

    if isinstance(model_update_aggregation_factory,
                  tff.aggregators.WeightedAggregationFactory):
        aggregation_process = model_update_aggregation_factory.create(
            weights_type.trainable, tff.TensorType(tf.float32))
    else:
        aggregation_process = model_update_aggregation_factory.create(
            weights_type.trainable)

    server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                       aggregation_process.initialize)
    server_state_type = server_init.type_signature.result.member
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              client_update_tf,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.type_at_server(server_state_type)

    federated_dataset_type = tff.type_at_clients(tf_dataset_type)

    run_one_round_tff = build_run_one_round_fn_attacked(
        server_update_fn, client_update_fn, aggregation_process,
        dummy_model_for_metadata, federated_server_state_type,
        federated_dataset_type)

    return tff.templates.IterativeProcess(initialize_fn=server_init,
                                          next_fn=run_one_round_tff)
Example n. 29
0
being, this test is here to ensure that we don't break JAX support as we evolve
it to be more complete and feature-full.
"""

import collections
import itertools

from absl.testing import absltest
import jax
import numpy as np
import tensorflow_federated as tff

from tensorflow_federated.python.tests import jax_components

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50, )))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10, )))
])


def loss(model, batch):
    y = jax.nn.softmax(
        jax.numpy.add(jax.numpy.matmul(batch['pixels'], model['weights']),
                      model['bias']))
    targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
    return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))
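A small eager sanity check, not part of the original test: with an all-zeros model the softmax is uniform over 10 classes, so the loss should be about -log(0.1) ≈ 2.30 regardless of the batch contents.

zero_model = collections.OrderedDict(
    weights=np.zeros((784, 10), dtype=np.float32),
    bias=np.zeros((10,), dtype=np.float32))
random_pixels = np.random.uniform(size=(50, 784)).astype(np.float32)
random_labels = np.random.randint(0, 10, size=(50,), dtype=np.int32)
sample_batch = collections.OrderedDict(pixels=random_pixels, labels=random_labels)
print(loss(zero_model, sample_batch))  # ~2.3026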


def build_fed_avg_process(
    total_clients: int,
    effective_num_clients: int,
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
    """Builds the TFF computations for optimization using federated averaging.

  Args:
    total_clients: The total number of participating clients; used to size the
      server-side per-client loss and weight tensors.
    effective_num_clients: The number of clients whose updates are kept by the
      loss-based selection step in each round.
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.
    aggregation_process: An optional `tff.templates.MeasuredProcess` used to
      aggregate the client model deltas; if `None`, a stateless mean is used.

  Returns:
    A `tff.templates.IterativeProcess`.
  """

    client_lr_schedule = client_lr
    if not callable(client_lr_schedule):
        client_lr_schedule = lambda round_num: client_lr

    server_lr_schedule = server_lr
    if not callable(server_lr_schedule):
        server_lr_schedule = lambda round_num: server_lr

    with tf.Graph().as_default():
        dummy_model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(dummy_model)
        dummy_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(dummy_model, dummy_optimizer)
        optimizer_variable_type = type_conversions.type_from_tensors(
            dummy_optimizer.variables())

    if aggregation_process is None:
        aggregation_process = build_stateless_mean(
            model_delta_type=model_weights_type.trainable)
    if not _is_valid_aggregation_process(aggregation_process):
        raise ProcessTypeError(
            'aggregation_process type signature does not conform to expected '
            'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
            ' Got: {t}'.format(t=aggregation_process.next.type_signature))

    initialize_computation = build_server_init_fn(
        model_fn=model_fn,
        effective_num_clients=effective_num_clients,
        # Initialize with the learning rate for round zero.
        server_optimizer_fn=lambda: server_optimizer_fn(server_lr_schedule(0)),
        aggregation_process=aggregation_process)

    # server_state_type = initialize_computation.type_signature.result
    # model_weights_type = server_state_type.model
    round_num_type = tf.float32

    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
    model_input_type = tff.SequenceType(dummy_model.input_spec)

    client_losses_at_server_type = tff.TensorType(dtype=tf.float32,
                                                  shape=[total_clients, 1])
    clients_weights_at_server_type = tff.TensorType(dtype=tf.float32,
                                                    shape=[total_clients, 1])

    aggregation_state = aggregation_process.initialize.type_signature.result.member

    server_state_type = ServerState(
        model=model_weights_type,
        optimizer_state=optimizer_variable_type,
        round_num=round_num_type,
        effective_num_clients=tf.int32,
        delta_aggregate_state=aggregation_state,
    )

    # @computations.tf_computation(clients_weights_type)
    # def get_zero_weights_all_clients(weights):
    #   return tf.zeros_like(weights, dtype=tf.float32)

    ######################################################
    # def federated_output(local_outputs):
    #   return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)

    # federated_output_computation = computations.federated_computation(
    #       federated_output, federated_local_outputs_type)

    single_id_type = tff.TensorType(dtype=tf.int32, shape=[1, 1])

    @tff.tf_computation(model_input_type, model_weights_type, round_num_type,
                        single_id_type)
    def client_update_fn(tf_dataset, initial_model_weights, round_num,
                         client_id):
        client_lr = client_lr_schedule(round_num)
        client_optimizer = client_optimizer_fn(client_lr)
        client_update = create_client_update_fn()
        return client_update(model_fn(), tf_dataset, initial_model_weights,
                             client_optimizer, client_id, client_weight_fn)

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        model = model_fn()
        server_lr = server_lr_schedule(server_state.round_num)
        server_optimizer = server_optimizer_fn(server_lr)
        # We initialize the server optimizer variables to avoid creating them
        # within the scope of the tf.function server_update.
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    id_type = tff.TensorType(shape=[1, 1], dtype=tf.int32)

    @tff.tf_computation(clients_weights_at_server_type, id_type)
    def select_weight_fn(clients_weights, local_id):
        return select_weight(clients_weights, local_id)

    @tff.tf_computation(client_losses_at_server_type,
                        clients_weights_at_server_type, tf.int32)
    def zero_small_loss_clients(losses_at_server, weights_at_server,
                                effective_num_clients):
        """Receives losses and returns participating clients.

    Args:
      losses_at_server: A `[total_clients, 1]` tensor of per-client losses.
      weights_at_server: A `[total_clients, 1]` tensor of per-client weights.
      effective_num_clients: The number of clients to keep as participants.

    Returns:
      The re-defined client weights after selecting `effective_num_clients`
      participating clients.
    """
        return redefine_client_weight(losses_at_server, weights_at_server,
                                      effective_num_clients)

    # @tff.tf_computation(client_losses_type)
    # def dataset_to_tensor_fn(dataset):
    #   return dataset_to_tensor(dataset)
    @tff.federated_computation(tff.FederatedType(server_state_type,
                                                 tff.SERVER),
                               tff.FederatedType(tf_dataset_type, tff.CLIENTS),
                               tff.FederatedType(id_type, tff.CLIENTS))
    def run_one_round(server_state, federated_dataset, ids):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num, ids))

        client_weight = client_outputs.client_weight
        client_id = client_outputs.client_id

        #LOSS SELECTION:
        # losses_at_server = tff.federated_collect(client_outputs.model_output)
        # weights_at_server = tff.federated_collect(client_weight)
        @computations.tf_computation
        def zeros_fn():
            return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

        zero = zeros_fn()

        at_server_type = tff.TensorType(shape=[total_clients, 1],
                                        dtype=tf.float32)
        # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
        client_output_type = client_update_fn.type_signature.result

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_weight(u, t):
            value = t.client_weight
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_loss(u, t):
            value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                               shape=[1, 1])
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        # output_at_server= tff.federated_collect(client_outputs)

        weights_at_server = tff.federated_reduce(client_outputs, zero,
                                                 accumulate_weight)
        losses_at_server = tff.federated_reduce(client_outputs, zero,
                                                accumulate_loss)
        #losses_at_server = tff.federated_aggregate(client_outputs.model_output, zero, accumulate, merge, report)

        selected_clients_weights = tff.federated_map(
            zero_small_loss_clients, (losses_at_server, weights_at_server,
                                      server_state.effective_num_clients))

        # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)

        selected_clients_weights_broadcast = tff.federated_broadcast(
            selected_clients_weights)

        selected_clients_weights_at_client = tff.federated_map(
            select_weight_fn, (selected_clients_weights_broadcast, ids))

        aggregation_output = aggregation_process.next(
            server_state.delta_aggregate_state, client_outputs.weights_delta,
            selected_clients_weights_at_client)

        # model_delta = tff.federated_mean(
        #     client_outputs.weights_delta, weight=client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregation_output.result))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    # @tff.federated_computation
    # def initialize_fn():
    #   return tff.federated_value(server_init_tf(), tff.SERVER)

    return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                          next_fn=run_one_round)
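A hedged usage sketch, not from the original code, reusing an illustrative `tff.learning.Model` builder such as the Keras-backed `model_fn` sketched earlier. The commented round call shows the `(datasets, ids)` inputs that `run_one_round` expects, where each id is a `[1, 1]` int32 tensor indexing that client's row in the server-side loss and weight matrices.

fed_avg_process = build_fed_avg_process(
    total_clients=100,
    effective_num_clients=10,
    model_fn=model_fn,  # illustrative no-arg tff.learning.Model builder
    client_optimizer_fn=tf.keras.optimizers.SGD,
    client_lr=0.1,
    server_optimizer_fn=tf.keras.optimizers.SGD,
    server_lr=1.0)

state = fed_avg_process.initialize()
# `sampled_datasets` would hold one preprocessed tf.data.Dataset per sampled
# client, and `sampled_ids` the matching [1, 1] int32 client id tensors:
#   state, metrics = fed_avg_process.next(state, sampled_datasets, sampled_ids)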