def test_invoke_returns_result_with_tf_computation(self):
    """Invokes a nested, variable-creating tf_computation in a TF graph context.

    `foo` computes make_10() -> add_one -> add_one_with_v1 -> add_one_with_v2
    (10 + 1 + 1 + 1 = 13) and then adds `zero` and `ten - ten`, which leave the
    value unchanged, so the fetched result must be 13.
    """
    make_10 = tensorflow_computation.tf_computation(lambda: tf.constant(10))
    add_one = tensorflow_computation.tf_computation(lambda x: tf.add(x, 1),
                                                    tf.int32)

    # Declared without an explicit parameter type; each of these helpers
    # creates its own tf.Variable holding 1 and adds it to the argument.
    @tensorflow_computation.tf_computation
    def add_one_with_v1(x):
      v1 = tf.Variable(1, name='v1')
      return x + v1

    @tensorflow_computation.tf_computation
    def add_one_with_v2(x):
      v2 = tf.Variable(1, name='v2')
      return x + v2

    @tensorflow_computation.tf_computation
    def foo():
      zero = tf.Variable(0, name='zero')
      # A variable whose initial value is the output of another computation.
      ten = tf.Variable(make_10())
      return (add_one_with_v2(add_one_with_v1(add_one(make_10()))) + zero +
              ten - ten)

    # The context is bound to `graph`; the token tensor is created inside that
    # graph. NOTE(review): the value 'bogus_token' suggests the token's
    # content does not matter for this test — confirm.
    with tf.compat.v1.Graph().as_default() as graph:
      context = tensorflow_computation_context.TensorFlowComputationContext(
          graph, tf.constant('bogus_token'))

    self.assertEqual(foo.type_signature.compact_representation(), '( -> int32)')
    # `foo` takes no parameters, so it is invoked with a `None` argument. This
    # call is made outside the `as_default()` scope; the returned value is
    # fetched below through a session on `graph`.
    x = context.invoke(foo, None)

    with tf.compat.v1.Session(graph=graph) as sess:
      # Run any init ops the context collected (presumably the initializers of
      # the variables created while invoking `foo`) before fetching the result.
      if context.init_ops:
        sess.run(context.init_ops)
      result = sess.run(x)
    self.assertEqual(result, 13)
Exemplo n.º 2
0
 def test_measured_process_output_as_state_raises(self):
   """A MeasuredProcessOutput used as the initial state must be rejected.

   The initialize computation returns an empty MeasuredProcessOutput, so
   constructing the MeasuredProcess is expected to raise
   TemplateStateNotAssignableError.
   """

   def make_empty_output():
     return MeasuredProcessOutput((), (), ())

   init_comp = tensorflow_computation.tf_computation(make_empty_output)
   state_type = init_comp.type_signature.result
   next_comp = tensorflow_computation.tf_computation(state_type)(
       lambda state: make_empty_output())

   with self.assertRaises(errors.TemplateStateNotAssignableError):
     measured_process.MeasuredProcess(init_comp, next_comp)
Exemplo n.º 3
0
 def test_construction_with_empty_state_does_not_raise(self):
   """Constructing a MeasuredProcess whose state is `()` must not raise."""
   initialize_fn = tensorflow_computation.tf_computation()(lambda: ())
   next_fn = tensorflow_computation.tf_computation(
       ())(lambda x: MeasuredProcessOutput(x, (), ()))
   try:
     measured_process.MeasuredProcess(initialize_fn, next_fn)
   except Exception:  # pylint: disable=broad-except
     # A bare `except:` would also turn KeyboardInterrupt / SystemExit into a
     # test failure; only ordinary errors should fail the test here.
     self.fail('Could not construct an MeasuredProcess with empty state.')
Exemplo n.º 4
0
 def get_bounds(state):
     """Returns `(upper_bound, lower_bound)` derived from `process.report`.

     The reported value is cast to `bound_dtype` to form the upper bound, and
     the lower bound is that upper bound multiplied by -1.0.
     """
     to_bound_dtype = tensorflow_computation.tf_computation(
         lambda x: tf.cast(x, bound_dtype))
     reported = process.report(state)
     upper = intrinsics.federated_map(to_bound_dtype, reported)

     negate = tensorflow_computation.tf_computation(lambda x: x * -1.0)
     lower = intrinsics.federated_map(negate, upper)
     return upper, lower
 def test_construction_with_empty_state_does_not_raise(self):
     """Constructing an IterativeProcess whose state is `()` must not raise."""
     initialize_fn = tensorflow_computation.tf_computation()(lambda: ())
     next_fn = tensorflow_computation.tf_computation(())(lambda x: (x, 1.0))
     try:
         iterative_process.IterativeProcess(initialize_fn, next_fn)
     except Exception:  # pylint: disable=broad-except
         # A bare `except:` would also turn KeyboardInterrupt / SystemExit
         # into a test failure; only ordinary errors should fail the test.
         self.fail(
             'Could not construct an IterativeProcess with empty state.')
Exemplo n.º 6
0
 def update_state(state, value_min, value_max):
     """Advances both bound processes with dtype-cast extreme values.

     `value_min` is cast to `min_dtype` and fed to `lower_bound_process`, and
     `value_max` is cast to `max_dtype` and fed to `upper_bound_process`.
     `state` is indexed as `(upper_state, lower_state)`; the two new states
     are returned zipped in that same order.
     """
     cast_to_min_dtype = tensorflow_computation.tf_computation(
         lambda x: tf.cast(x, min_dtype))
     value_min = intrinsics.federated_map(cast_to_min_dtype, value_min)
     cast_to_max_dtype = tensorflow_computation.tf_computation(
         lambda x: tf.cast(x, max_dtype))
     value_max = intrinsics.federated_map(cast_to_max_dtype, value_max)

     new_upper_state = upper_bound_process.next(state[0], value_max)
     new_lower_state = lower_bound_process.next(state[1], value_min)
     return intrinsics.federated_zip((new_upper_state, new_lower_state))
Exemplo n.º 7
0
 def next_fn(state, value):
     """Test `next` step that increments the state and the summed value.

     The state is mapped through `x + 1`. `value` is federated-summed, and then
     every leaf of the sum is incremented by 1. The measurements are
     `MEASUREMENT_CONSTANT` placed at the server.
     """
     increment = tensorflow_computation.tf_computation(lambda x: x + 1)
     new_state = intrinsics.federated_map(increment, state)

     increment_leaves = tensorflow_computation.tf_computation(
         lambda x: tf.nest.map_structure(lambda y: y + 1, x))
     summed_value = intrinsics.federated_sum(value)
     result = intrinsics.federated_map(increment_leaves, summed_value)

     measurements = intrinsics.federated_value(MEASUREMENT_CONSTANT,
                                               placements.SERVER)
     return measured_process.MeasuredProcessOutput(
         state=new_state, result=result, measurements=measurements)
Exemplo n.º 8
0
    def test_invoke_with_typed_fn(self):
        """An explicitly int32-typed comparison is typed `(int32 -> bool)`."""

        @tensorflow_computation.tf_computation(tf.int32)
        def foo(x):
            return x > 10

        signature = foo.type_signature.compact_representation()
        self.assertEqual(signature, '(int32 -> bool)')
Exemplo n.º 9
0
    def test_invoke_with_no_arg_fn(self):
        """A no-argument function returning 10 is typed `( -> int32)`."""

        @tensorflow_computation.tf_computation
        def foo():
            return 10

        signature = foo.type_signature.compact_representation()
        self.assertEqual(signature, '( -> int32)')
Exemplo n.º 10
0
    def one_round_computation(examples):
        """The TFF computation to compute the aggregated IBLT sketch.

        Maps `compute_sketch` over the client `examples`, sums the sketches and
        counts (securely when `secure_sum_bitwidth` is set), decodes them with
        `decode_heavy_hitters`, and returns a zipped `ServerOutput` that also
        carries the summed client count and a server-side timestamp.
        """
        # With a secure bitwidth configured, IBLT sketches are summed with a
        # secure modular sum because they are decoded modulo the field size.
        if secure_sum_bitwidth is None:
            sketch_sum_fn = count_sum_fn = intrinsics.federated_sum
        else:
            sketch_sum_fn, count_sum_fn = secure_modular_sum, secure_sum

        timestamp_fn = tensorflow_computation.tf_computation(
            lambda: tf.cast(tf.timestamp(), tf.int64))
        round_timestamp = intrinsics.federated_eval(timestamp_fn,
                                                    placements.SERVER)

        # Every client contributes 1, so the sum counts the clients.
        one_per_client = intrinsics.federated_value(1, placements.CLIENTS)
        clients = count_sum_fn(one_per_client)

        local_sketch, local_count = intrinsics.federated_map(
            compute_sketch, examples)
        total_sketch = sketch_sum_fn(local_sketch)
        total_count = count_sum_fn(local_count)

        decoded = intrinsics.federated_map(decode_heavy_hitters,
                                           (total_sketch, total_count))
        (heavy_hitters, heavy_hitters_unique_counts, heavy_hitters_counts,
         num_not_decoded) = decoded

        return intrinsics.federated_zip(
            ServerOutput(
                clients=clients,
                heavy_hitters=heavy_hitters,
                heavy_hitters_unique_counts=heavy_hitters_unique_counts,
                heavy_hitters_counts=heavy_hitters_counts,
                num_not_decoded=num_not_decoded,
                round_timestamp=round_timestamp))
Exemplo n.º 11
0
    def test_roundtrip(self):
        """Round-trips a broadcast-then-map computation through BroadcastForm."""
        add = tensorflow_computation.tf_computation(lambda x, y: x + y)

        @federated_computation.federated_computation(
            computation_types.at_server(tf.int32),
            computation_types.at_clients(tf.int32))
        def add_server_number_plus_one(server_number, client_numbers):
            # (1 + server_number) is computed at the server, broadcast, and
            # then added to every client number.
            one_at_server = intrinsics.federated_value(1, placements.SERVER)
            incremented = intrinsics.federated_map(
                add, (one_at_server, server_number))
            at_clients = intrinsics.federated_broadcast(incremented)
            return intrinsics.federated_map(add, (at_clients, client_numbers))

        broadcast_form = form_utils.get_broadcast_form_for_computation(
            add_server_number_plus_one)
        # The labels are expected to match the computation's parameter names.
        self.assertEqual(broadcast_form.server_data_label, 'server_number')
        self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

        type_test_utils.assert_types_equivalent(
            broadcast_form.compute_server_context.type_signature,
            computation_types.FunctionType(tf.int32, (tf.int32, )))
        self.assertEqual(2, broadcast_form.compute_server_context(1)[0])

        type_test_utils.assert_types_equivalent(
            broadcast_form.client_processing.type_signature,
            computation_types.FunctionType(((tf.int32, ), tf.int32), tf.int32))
        self.assertEqual(3, broadcast_form.client_processing((1, ), 2))

        round_trip_comp = form_utils.get_computation_for_broadcast_form(
            broadcast_form)
        type_test_utils.assert_types_equivalent(
            round_trip_comp.type_signature,
            add_server_number_plus_one.type_signature)
        # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
        self.assertEqual([5, 6, 7], round_trip_comp(2, [2, 3, 4]))
Exemplo n.º 12
0
 def init_fn():
     """Returns the zipped `(optimizer_state, aggregator_state)` initial state.

     The optimizer state is `optimizer.initialize` applied to the trainable
     weight specs, evaluated at the server; the aggregator state comes from
     `full_gradient_aggregator.initialize()`.
     """
     trainable_specs = weight_tensor_specs.trainable

     @tensorflow_computation.tf_computation
     def initialize_optimizer():
         return optimizer.initialize(trainable_specs)

     optimizer_state = intrinsics.federated_eval(initialize_optimizer,
                                                 placements.SERVER)
     aggregator_state = full_gradient_aggregator.initialize()
     return intrinsics.federated_zip((optimizer_state, aggregator_state))
Exemplo n.º 13
0
 def test_next_return_namedtuple_raises(self):
   """A namedtuple look-alike of MeasuredProcessOutput must be rejected.

   The namedtuple shares the name and field names of MeasuredProcessOutput but
   is not that type, so TemplateNotMeasuredProcessOutputError is expected.
   """
   look_alike_output = collections.namedtuple(
       'MeasuredProcessOutput', ['state', 'result', 'measurements'])

   @tensorflow_computation.tf_computation(tf.int32)
   def namedtuple_next_fn(state):
     return look_alike_output(state, (), ())

   with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
     measured_process.MeasuredProcess(test_initialize_fn, namedtuple_next_fn)
Exemplo n.º 14
0
    def test_roundtrip_no_broadcast(self):
        """Round-trips a computation that ignores its empty server data."""
        add_five = tensorflow_computation.tf_computation(lambda x: x + 5)

        @federated_computation.federated_computation(
            computation_types.at_server(()),
            computation_types.at_clients(tf.int32))
        def add_five_at_clients(naught_at_server, client_numbers):
            del naught_at_server  # The server data is intentionally unused.
            return intrinsics.federated_map(add_five, client_numbers)

        broadcast_form = form_utils.get_broadcast_form_for_computation(
            add_five_at_clients)
        self.assertEqual(broadcast_form.server_data_label, 'naught_at_server')
        self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

        # With nothing broadcast, the server context is the empty structure.
        type_test_utils.assert_types_equivalent(
            broadcast_form.compute_server_context.type_signature,
            computation_types.FunctionType((), ()))
        type_test_utils.assert_types_equivalent(
            broadcast_form.client_processing.type_signature,
            computation_types.FunctionType(((), tf.int32), tf.int32))
        self.assertEqual(6, broadcast_form.client_processing((), 1))

        round_trip_comp = form_utils.get_computation_for_broadcast_form(
            broadcast_form)
        type_test_utils.assert_types_equivalent(
            round_trip_comp.type_signature, add_five_at_clients.type_signature)
        self.assertEqual([10, 11, 12], round_trip_comp((), [5, 6, 7]))
Exemplo n.º 15
0
 def update_state(state, value_min, value_max):
     """Advances `process` with the larger magnitude of `value_min`/`value_max`.

     The element-wise maximum of the absolute values is cast to
     `expected_dtype` before being passed to `process.next`.
     """

     def abs_max(x, y):
         return tf.cast(tf.maximum(tf.abs(x), tf.abs(y)), expected_dtype)

     abs_value_max = intrinsics.federated_map(
         tensorflow_computation.tf_computation(abs_max),
         (value_min, value_max))
     return process.next(state, abs_value_max)
 def next_fn(strings, val):
   """Appends 'abc' to the `strings` state and federated-sums `val`.

   The measurements are the constant 1 placed at the server.
   """
   append_abc = tensorflow_computation.tf_computation()(
       lambda s: tf.concat([s, tf.constant(['abc'])], axis=0))
   new_state = intrinsics.federated_map(append_abc, strings)
   total = intrinsics.federated_sum(val)
   measurement = intrinsics.federated_value(1, placements.SERVER)
   return MeasuredProcessOutput(new_state, total, measurement)
Exemplo n.º 17
0
    def _create_next_fn(self, inner_agg_next, state_type, value_type):
        """Builds the federated `next_fn` that clips around an inner aggregation.

        Client values are passed through `_modular_clip_by_value` with the
        bounds `self._clip_range_lower` / `self._clip_range_upper` before
        being handed to `inner_agg_next`. The inner aggregate result is then
        passed through the same clipping function again with the
        server-placed bounds.

        Args:
          inner_agg_next: The inner aggregation's `next` computation, called as
            `inner_agg_next(state, clipped_value)`; its output's `.state`,
            `.result` and `.measurements` are read.
          state_type: Type spec of the `state` parameter of the returned
            computation.
          value_type: Member type of the `value` parameter, which is placed at
            clients via `computation_types.at_clients`.

        Returns:
          A federated computation `(state, value) -> MeasuredProcessOutput`.
        """

        modular_clip_by_value_fn = tensorflow_computation.tf_computation(
            _modular_clip_by_value)

        @federated_computation.federated_computation(
            state_type, computation_types.at_clients(value_type))
        def next_fn(state, value):
            # The clip bounds are server-placed constants from the configured
            # clip range.
            clip_lower = intrinsics.federated_value(self._clip_range_lower,
                                                    placements.SERVER)
            clip_upper = intrinsics.federated_value(self._clip_range_upper,
                                                    placements.SERVER)

            # Modular clip values before aggregation.
            # The bounds are broadcast so they are placed alongside the
            # client-placed values.
            clipped_value = intrinsics.federated_map(
                modular_clip_by_value_fn,
                (value, intrinsics.federated_broadcast(clip_lower),
                 intrinsics.federated_broadcast(clip_upper)))

            inner_agg_output = inner_agg_next(state, clipped_value)

            # Clip the aggregate to the same range again (not considering summands).
            clipped_agg_output_result = intrinsics.federated_map(
                modular_clip_by_value_fn,
                (inner_agg_output.result, clip_lower, clip_upper))

            # Nest the inner aggregation's measurements under 'modclip'.
            measurements = collections.OrderedDict(
                modclip=inner_agg_output.measurements)

            # The inner aggregation's state is passed through unchanged.
            return measured_process.MeasuredProcessOutput(
                state=inner_agg_output.state,
                result=clipped_agg_output_result,
                measurements=intrinsics.federated_zip(measurements))

        return next_fn
Exemplo n.º 18
0
 def test_next_state_not_assignable_tuple_result(self):
     """A next_fn returning a (float32, float32) tuple must be rejected.

     EstimationProcess is expected to raise TemplateStateNotAssignableError.
     """

     @tensorflow_computation.tf_computation(tf.float32, tf.float32)
     def float_next_fn(state, x):
         return tf.cast(state, tf.float32), x

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         estimation_process.EstimationProcess(
             test_initialize_fn, float_next_fn, test_report_fn)
Exemplo n.º 19
0
 def test_map_estimate_not_assignable(self):
     """`map` with an int32-typed identity must raise EstimateNotAssignableError."""

     @tensorflow_computation.tf_computation(tf.int32)
     def map_fn(estimate):
         return estimate

     process = estimation_process.EstimationProcess(
         test_initialize_fn, test_next_fn, test_report_fn)
     with self.assertRaises(estimation_process.EstimateNotAssignableError):
         process.map(map_fn)
Exemplo n.º 20
0
    def test_takes_tuple_typed(self):
        """A tuple type spec yields an unnamed two-element struct parameter."""

        @tensorflow_computation.tf_computation((tf.int32, tf.int32))
        @tf.function
        def foo(t):
            return t[0] + t[1]

        signature = foo.type_signature.compact_representation()
        self.assertEqual(signature, '(<int32,int32> -> int32)')
Exemplo n.º 21
0
def _run_in_tf_computation(optimizer, spec):
    """Runs three optimizer steps through TFF tf_computations.

    Weights and gradients are all-ones tensors shaped and typed like `spec`.
    The optimizer state comes from `optimizer.initialize(spec)` and is advanced
    with `optimizer.next(state, weights, gradients)`, both wrapped as
    tf_computations.

    Returns:
      A pair `(state_history, weights_history)`; each list holds the initial
      value followed by the value after each of the three steps.
    """

    def ones_for(s):
        return tf.ones(s.shape, s.dtype)

    weights = tf.nest.map_structure(ones_for, spec)
    gradients = tf.nest.map_structure(ones_for, spec)
    init_fn = tensorflow_computation.tf_computation(
        lambda: optimizer.initialize(spec))
    next_fn = tensorflow_computation.tf_computation(optimizer.next)

    state = init_fn()
    state_history, weights_history = [state], [weights]
    for _ in range(3):
        state, weights = next_fn(state, weights, gradients)
        state_history.append(state)
        weights_history.append(weights)
    return state_history, weights_history
Exemplo n.º 22
0
 def next_fn(state, weights, updates):
     """Adds `updates` to the trainable weights and passes `state` through.

     The result is `ModelWeights(trainable + updates, ())` zipped into a single
     federated value; the measurements come from `empty_at_server()`.
     """
     add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
     updated_trainable = intrinsics.federated_map(
         add_fn, (weights.trainable, updates))
     packed = model_utils.ModelWeights(updated_trainable, ())
     new_weights = intrinsics.federated_zip(packed)
     return measured_process.MeasuredProcessOutput(
         state=state, result=new_weights, measurements=empty_at_server())
  def test_non_federated_init_next_raises(self):
    """AggregationProcess rejects non-federated init/next computations."""

    @tensorflow_computation.tf_computation
    def initialize_fn():
      return 0

    next_fn = tensorflow_computation.tf_computation(tf.int32, tf.float32)(
        lambda state, val: MeasuredProcessOutput(state, val, ()))

    with self.assertRaises(aggregation_process.AggregationNotFederatedError):
      aggregation_process.AggregationProcess(initialize_fn, next_fn)
Exemplo n.º 24
0
 def next_comp(state, value):
     """Applies `_add_one` to the state and broadcasts `value` as the result.

     The measurement is `global_norm(flatten(value)) + 3.0`, an arbitrary value
     used only for testing.
     """
     new_state = intrinsics.federated_map(_add_one, state)
     broadcast_value = intrinsics.federated_broadcast(value)
     arbitrary_metric_fn = tensorflow_computation.tf_computation(
         lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
     metrics = intrinsics.federated_map(arbitrary_metric_fn, value)
     return measured_process.MeasuredProcessOutput(
         state=new_state, result=broadcast_value, measurements=metrics)
Exemplo n.º 25
0
def _create_test_iterative_process(state_type, state_init):
  """Returns an IterativeProcess whose next_fn returns the state unchanged.

  The initial state is `tf.constant(state_init)`, and next_fn is typed with
  `state_type`.
  """
  next_fn = tensorflow_computation.tf_computation(state_type)(
      lambda state: state)

  @tensorflow_computation.tf_computation
  def initialize_fn():
    return tf.constant(state_init)

  return iterative_process.IterativeProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
Exemplo n.º 26
0
def _test_map_reduce_form_computations():
    """Returns a tuple of minimal stub computations for MapReduceForm tests.

    Every computation ignores its inputs and returns constants. The tuple is
    `(initialize, prepare, work, zero, accumulate, merge, report, bitwidth,
    max_input, modulus, update)`. NOTE(review): presumably this is the
    positional order the MapReduceForm constructor expects — confirm against
    callers.
    """
    @tensorflow_computation.tf_computation
    def initialize():
        return tf.constant(0)

    @tensorflow_computation.tf_computation(tf.int32)
    def prepare(server_state):
        del server_state  # Unused
        return tf.constant(1.0)

    # Returns a bool followed by three empty lists. The bool presumably lines
    # up with the `tf.bool` client_update parameter of `accumulate` below.
    @tensorflow_computation.tf_computation(
        computation_types.SequenceType(tf.float32), tf.float32)
    def work(client_data, client_input):
        del client_data  # Unused
        del client_input  # Unused
        return True, [], [], []

    @tensorflow_computation.tf_computation
    def zero():
        return tf.constant(0), tf.constant(0)

    @tensorflow_computation.tf_computation((tf.int32, tf.int32), tf.bool)
    def accumulate(accumulator, client_update):
        del accumulator  # Unused
        del client_update  # Unused
        return tf.constant(1), tf.constant(1)

    @tensorflow_computation.tf_computation((tf.int32, tf.int32),
                                           (tf.int32, tf.int32))
    def merge(accumulator1, accumulator2):
        del accumulator1  # Unused
        del accumulator2  # Unused
        return tf.constant(1), tf.constant(1)

    # NOTE(review): two type arguments on a one-parameter function; presumably
    # they are packed into a single <int32,int32> accumulator parameter.
    # `accumulate`/`merge` spell this as the tuple `(tf.int32, tf.int32)` —
    # consider the same form here for consistency (verify the type signature
    # is unchanged).
    @tensorflow_computation.tf_computation(tf.int32, tf.int32)
    def report(accumulator):
        del accumulator  # Unused
        return tf.constant(1.0)

    # `bitwidth`, `max_input` and `modulus` all reuse one computation that
    # returns an empty list (a unit value of type `unit_type`).
    unit_comp = tensorflow_computation.tf_computation(lambda: [])
    bitwidth = unit_comp
    max_input = unit_comp
    modulus = unit_comp
    unit_type = computation_types.to_type([])

    # NOTE(review): the float32 member of `global_update` presumably matches
    # `report`'s output, and the three unit members match the three
    # unit-returning computations above — confirm.
    @tensorflow_computation.tf_computation(
        tf.int32, (tf.float32, unit_type, unit_type, unit_type))
    def update(server_state, global_update):
        del server_state  # Unused
        del global_update  # Unused
        return tf.constant(1), []

    return (initialize, prepare, work, zero, accumulate, merge, report,
            bitwidth, max_input, modulus, update)
Exemplo n.º 27
0
    def test_invoke_with_polymorphic_lambda(self):
        """An untyped tf_computation is specialized per concrete argument type."""
        foo = tensorflow_computation.tf_computation(lambda x: x > 10)

        cases = ((tf.int32, '(int32 -> bool)'),
                 (tf.float32, '(float32 -> bool)'))
        for dtype, expected_signature in cases:
            concrete_fn = foo.fn_for_argument_type(
                computation_types.TensorType(dtype))
            self.assertEqual(
                concrete_fn.type_signature.compact_representation(),
                expected_signature)
    def hierarchical_histogram_computation(federated_client_data):
        """Aggregates client histograms and stamps the round at the server.

        `client_work` is mapped over `federated_client_data`, and the results
        are passed to `process.next` starting from a fresh
        `process.initialize()` state on every invocation. Returns a zipped
        `ServerOutput(result, round_timestamp)`.
        """
        timestamp_fn = tensorflow_computation.tf_computation(
            lambda: tf.cast(tf.timestamp(), tf.int64))
        round_timestamp = intrinsics.federated_eval(timestamp_fn,
                                                    placements.SERVER)
        client_histogram = intrinsics.federated_map(client_work,
                                                    federated_client_data)

        fresh_state = process.initialize()
        aggregated = process.next(fresh_state, client_histogram).result
        return intrinsics.federated_zip(
            ServerOutput(aggregated, round_timestamp))
Exemplo n.º 29
0
 def _compute_measurements(self, upper_bound, lower_bound, value_max,
                           value_min):
     """Creates measurements to be reported. All values are summed securely.

     For each threshold, the threshold is broadcast, each value beyond it is
     flagged as 1 (cast to `COUNT_TF_TYPE`), and the flags are summed with
     `federated_secure_sum_bitwidth` using bitwidth 1. Returns a federated zip
     of an OrderedDict holding both clipped counts and both thresholds.
     """

     def secure_clipped_count(threshold, client_values, is_clipped):
         # Flag each value that `is_clipped` against the broadcast threshold,
         # then securely sum the 0/1 flags.
         flag_fn = tensorflow_computation.tf_computation(
             lambda bound, value: tf.cast(is_clipped(bound, value),
                                          COUNT_TF_TYPE))
         flags = intrinsics.federated_map(
             flag_fn,
             (intrinsics.federated_broadcast(threshold), client_values))
         return intrinsics.federated_secure_sum_bitwidth(flags, bitwidth=1)

     # Values in `value_max` strictly above the upper bound ...
     max_clipped_count = secure_clipped_count(
         upper_bound, value_max, lambda bound, value: bound < value)
     # ... and values in `value_min` strictly below the lower bound.
     min_clipped_count = secure_clipped_count(
         lower_bound, value_min, lambda bound, value: bound > value)

     return intrinsics.federated_zip(
         collections.OrderedDict(
             secure_upper_clipped_count=max_clipped_count,
             secure_lower_clipped_count=min_clipped_count,
             secure_upper_threshold=upper_bound,
             secure_lower_threshold=lower_bound))
Exemplo n.º 30
0
    def test_takes_namedtuple_typed(self):
        """A namedtuple type spec yields named struct members in the signature."""
        MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

        @tensorflow_computation.tf_computation(MyType(tf.int32, tf.int32))
        @tf.function
        def foo(x):
            # The argument must arrive as a MyType instance.
            self.assertIsInstance(x, MyType)
            return x.x + x.y

        signature = foo.type_signature.compact_representation()
        self.assertEqual(signature, '(<x=int32,y=int32> -> int32)')